In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from sklearn.tree import DecisionTreeClassifier


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 28)
        self.fc2 = nn.Linear(28, 28)
        self.fc3 = nn.Linear(28, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def train_model(model, train_loader, epochs=5):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    model.train()
    model.to(device)
    
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            

def get_train_loader(batch_size=256):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    loader_kwargs = {'num_workers': 4, 'pin_memory': True} if torch.cuda.is_available() else {}
    
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs)
    return train_loader

def extract_activations_and_correctness(model, data_loader):
    model.eval()
    model.to(device)
    all_activations = []
    all_correctness = []

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            
            x = data.view(-1, 28 * 28)
            h1 = F.relu(model.fc1(x))
            h2 = F.relu(model.fc2(h1))
            output = model.fc3(h2)

            h1_binary = (h1 > 0).float()
            h2_binary = (h2 > 0).float()


            predictions = torch.argmax(output, dim=1)
            is_correct = (predictions == target).float()
            activation_pattern = torch.cat([h1_binary, h2_binary], dim=1)
            all_activations.append(activation_pattern.cpu())
            all_correctness.append(is_correct.cpu())
    
    return torch.cat(all_activations).numpy(), torch.cat(all_correctness).numpy()

model = SimpleNN()
train_loader = get_train_loader()
train_model(model, train_loader)
activations, correctness_labels = extract_activations_and_correctness(model, train_loader)
tree = DecisionTreeClassifier(
    criterion='entropy',
    max_depth=1000, # lower = more accurate but less coverage
    min_samples_leaf=80, # higher = more accurate but less coverage
    random_state=42
)
tree.fit(activations, correctness_labels)

leaf_indices = tree.apply(activations) 
unique_leaves, leaf_counts = np.unique(leaf_indices, return_counts=True)
confident_leaves = []
for leaf_idx in unique_leaves:
    samples_in_leaf = (leaf_indices == leaf_idx)
    leaf_correctness = correctness_labels[samples_in_leaf]
    if np.all(leaf_correctness == 1):
            confident_leaves.append(leaf_idx)
confident_leaves = set(confident_leaves)

In [2]:
activations, correctness_labels = extract_activations_and_correctness(model, train_loader)
tree = DecisionTreeClassifier(
    criterion='entropy',
    max_depth=1000, # lower = more accurate but less coverage
    min_samples_leaf=465, # higher = more accurate but less coverage
    random_state=42
)
tree.fit(activations, correctness_labels)

leaf_indices = tree.apply(activations) 
unique_leaves, leaf_counts = np.unique(leaf_indices, return_counts=True)
confident_leaves = []
for leaf_idx in unique_leaves:
    samples_in_leaf = (leaf_indices == leaf_idx)
    leaf_correctness = correctness_labels[samples_in_leaf]
    if np.mean(leaf_correctness) > 0.999:
            confident_leaves.append(leaf_idx)
confident_leaves = set(confident_leaves)

In [3]:
def get_test_loader(batch_size=256):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    loader_kwargs = {'num_workers': 4, 'pin_memory': True} if torch.cuda.is_available() else {}
    
    test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, **loader_kwargs)
    return test_loader

test_loader = get_test_loader()
test_activations, test_correctness_labels = extract_activations_and_correctness(model, test_loader)
test_leaf_indices = tree.apply(test_activations)
confident_mask = np.array([leaf_idx in confident_leaves for leaf_idx in test_leaf_indices])
confident_test_labels = test_correctness_labels[confident_mask]
print(f"Coverage of confident NAPs: {np.mean(confident_mask):.3f} ({np.sum(confident_mask)}/{len(test_activations)} test samples)")
print(f"Accuracy on confident predictions: { np.mean(confident_test_labels):.4f}")
print(f"Model's true test accuracy: {np.sum(test_correctness_labels) / len(test_correctness_labels):.4f}")

Coverage of confident NAPs: 0.043 (425/10000 test samples)
Accuracy on confident predictions: 0.9929
Model's true test accuracy: 0.9520


In [4]:
def extract_output_logits(model, data_loader):
    model.eval()
    model.to(device)
    outputs = []

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            
            x = data.view(-1, 28 * 28)
            h1 = F.relu(model.fc1(x))
            h2 = F.relu(model.fc2(h1))
            output = model.fc3(h2)

            outputs.append(output)
    
    return torch.cat(outputs).numpy()

def get_test_loader(batch_size=256):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    loader_kwargs = {'num_workers': 4, 'pin_memory': True} if torch.cuda.is_available() else {}
    
    test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, **loader_kwargs)
    return test_loader

test_loader = get_test_loader()
test_outputs = extract_output_logits(model, test_loader)

confident_outputs_mask = np.array([max(output) > 10.2 for output in test_outputs])
confident_outputs_test_labels = test_correctness_labels[confident_outputs_mask]
outputs_coverage = np.mean(confident_outputs_mask)
print(f"Baseline coverage of confident NAPs: {outputs_coverage:.3f}")
print(f"Baseline Accuracy on confident predictions (how often model gets answer right): {np.sum(confident_outputs_test_labels) / len(confident_outputs_test_labels):.4f}")
print(f"Baseline Model's true test accuracy (how often model gets answer right): {np.sum(test_correctness_labels) / len(test_correctness_labels):.4f}")

Baseline coverage of confident NAPs: 0.058
Baseline Accuracy on confident predictions (how often model gets answer right): 1.0000
Baseline Model's true test accuracy (how often model gets answer right): 0.9520
