In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms


# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class SimpleNN(nn.Module):
    """Three-layer fully connected classifier over flattened 28x28 inputs.

    The layers are exposed as ``fc1``, ``fc2`` and ``fc3`` so that other code
    in this notebook can replay the forward pass layer by layer and inspect
    the hidden activations.

    Args:
        hidden_size: Width of both hidden layers. Defaults to 28, the value
            that was previously hard-coded, so ``SimpleNN()`` is unchanged.
    """

    def __init__(self, hidden_size=28):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 10)

    def forward(self, x):
        """Return the 10 raw logits (no softmax) for each input row."""
        # Collapse whatever leading/image dims arrive into rows of 784 features.
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
class SupervisorNN(nn.Module):
    """MLP that maps an activation-pattern vector to two raw logits.

    Both hidden layers are ten times as wide as the input. As trained by
    ``train_supervisor`` in this notebook, class 0 means "the base model was
    correct" and class 1 means "the base model was wrong".

    Args:
        activation_len: Number of features in each activation-pattern vector.
    """

    def __init__(self, activation_len):
        super().__init__()
        self.activation_len = activation_len
        hidden_width = activation_len * 10
        self.fc1 = nn.Linear(activation_len, hidden_width)
        self.fc2 = nn.Linear(hidden_width, hidden_width)
        self.fc3 = nn.Linear(hidden_width, 2)

    def forward(self, x):
        """Return the two logits for each row of ``x``."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

def train_model(model, train_loader, epochs=5):
    """Train ``model`` in place with Adam (lr=0.001) and cross-entropy loss.

    Args:
        model: Module mapping a batch of inputs to class logits.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        epochs: Number of full passes over ``train_loader``.

    The model is switched to training mode and moved to the module-level
    ``device``. Nothing is returned.
    """
    adam = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    model.to(device)

    for _ in range(epochs):
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            adam.zero_grad()
            batch_loss = loss_fn(model(inputs), labels)
            batch_loss.backward()
            adam.step()
            

def get_train_loader(batch_size=256):
    """Build a shuffled DataLoader over the MNIST training split.

    The dataset is stored under ``./data`` and downloaded if missing. Images
    are converted to tensors and normalized with mean 0.1307 / std 0.3081
    (presumably the MNIST dataset statistics; confirm if the data changes).

    Args:
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` yielding ``(image, label)`` batches in random order.
    """
    # Worker processes and pinned host memory are only requested with CUDA.
    if torch.cuda.is_available():
        loader_kwargs = {'num_workers': 4, 'pin_memory': True}
    else:
        loader_kwargs = {}

    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    mnist_train = datasets.MNIST('./data', train=True, download=True, transform=preprocessing)

    return DataLoader(mnist_train, batch_size=batch_size, shuffle=True, **loader_kwargs)

def extract_activations_and_correctness(model, data_loader):
    """Collect binary hidden-activation patterns and per-sample correctness.

    The forward pass is replayed layer by layer through ``model.fc1``,
    ``model.fc2`` and ``model.fc3`` so the hidden ReLU outputs can be observed.
    Each hidden unit is recorded as 1.0 if its ReLU output is positive and
    0.0 otherwise.

    Args:
        model: Network exposing ``fc1``, ``fc2`` and ``fc3`` layers. It is put
            in eval mode and moved to the module-level ``device``.
        data_loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Tuple ``(patterns, correctness)`` of NumPy arrays. ``patterns`` holds
        the fc1 and fc2 on/off flags concatenated per sample; ``correctness``
        is 1.0 where the argmax prediction equals the label, else 0.0.
    """
    model.eval()
    model.to(device)
    pattern_batches = []
    correct_batches = []

    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            labels = labels.to(device)

            hidden1 = F.relu(model.fc1(images.view(-1, 28 * 28)))
            hidden2 = F.relu(model.fc2(hidden1))
            logits = model.fc3(hidden2)

            # A unit counts as "on" when its post-ReLU value is strictly positive.
            fired = torch.cat([hidden1.gt(0).float(), hidden2.gt(0).float()], dim=1)
            hit = logits.argmax(dim=1).eq(labels).float()

            pattern_batches.append(fired.cpu())
            correct_batches.append(hit.cpu())

    patterns = torch.cat(pattern_batches).numpy()
    correctness = torch.cat(correct_batches).numpy()
    return patterns, correctness

def train_supervisor(supervisor, activations, correctness_labels, epochs=10, batch_size=256):
    """Fit ``supervisor`` to predict base-model errors from activation patterns.

    Args:
        supervisor: Module mapping activation vectors to two logits. It is
            moved to the module-level ``device`` and left in training mode.
        activations: NumPy array of activation patterns, one row per sample.
        correctness_labels: NumPy array with 1 where the base model was
            correct and 0 where it was wrong.
        epochs: Number of passes over the data.
        batch_size: Mini-batch size for the shuffled loader.

    Training uses Adam (lr=0.001) with cross-entropy loss. Nothing is returned.
    """
    features = torch.FloatTensor(activations)
    # Flip the labels: target class 0 = base model correct, class 1 = wrong.
    error_targets = torch.LongTensor(1 - correctness_labels.astype(int))

    batches = DataLoader(
        TensorDataset(features, error_targets),
        batch_size=batch_size,
        shuffle=True,
    )

    supervisor.to(device)
    supervisor.train()
    adam = torch.optim.Adam(supervisor.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()

    for _ in range(epochs):
        for feats, targets in batches:
            feats = feats.to(device)
            targets = targets.to(device)
            adam.zero_grad()
            batch_loss = loss_fn(supervisor(feats), targets)
            batch_loss.backward()
            adam.step()
            
# Seed PyTorch's global RNG so weight initialization and DataLoader shuffling
# repeat across Restart & Run All; without it every run yields different numbers.
SEED = 0
torch.manual_seed(SEED)

# 1) Train the base MNIST classifier.
model = SimpleNN()
train_loader = get_train_loader()
train_model(model, train_loader)

# 2) Record, on the training split, which hidden units fired for each sample and
#    whether the classifier's prediction was right.
activations, correctness_labels = extract_activations_and_correctness(model, train_loader)

# 3) Train the supervisor to predict correctness from those activation patterns.
supervisor = SupervisorNN(activations.shape[1])
train_supervisor(supervisor, activations, correctness_labels, epochs=20)

Epoch 0, Loss: 0.1645
Epoch 2, Loss: 0.1482
Epoch 4, Loss: 0.1428
Epoch 6, Loss: 0.1389
Epoch 8, Loss: 0.1354
Epoch 10, Loss: 0.1322
Epoch 12, Loss: 0.1272
Epoch 14, Loss: 0.1236
Epoch 16, Loss: 0.1206
Epoch 18, Loss: 0.1158


In [5]:
def get_test_loader(batch_size=256):
    """Build an unshuffled DataLoader over the MNIST test split.

    The dataset is stored under ``./data`` and downloaded if missing. Images
    get the same ToTensor + Normalize((0.1307,), (0.3081,)) preprocessing as
    the training loader in this notebook.

    Args:
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` yielding ``(image, label)`` batches in dataset order.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    # Worker processes and pinned host memory are only requested with CUDA.
    cuda_options = {'num_workers': 4, 'pin_memory': True} if torch.cuda.is_available() else {}

    mnist_test = datasets.MNIST('./data', train=False, download=True, transform=preprocessing)
    return DataLoader(mnist_test, batch_size=batch_size, shuffle=False, **cuda_options)

test_loader = get_test_loader()
test_activations, test_correctness_labels = extract_activations_and_correctness(model, test_loader)

# A test sample counts as "confident" when the supervisor's class-0 logit
# (class 0 = base model correct, per the label flip in train_supervisor)
# exceeds this hand-tuned value.
SUPERVISOR_LOGIT_THRESHOLD = 1.8

supervisor.eval()
with torch.no_grad():
    # Score every test activation pattern in one batched pass, then bring the
    # logits back to host memory so NumPy can index them even when running on CUDA.
    activation_batch = torch.as_tensor(test_activations, dtype=torch.float32).to(device)
    supervisor_outputs = supervisor(activation_batch).cpu().numpy()

confident_outputs_mask = supervisor_outputs[:, 0] > SUPERVISOR_LOGIT_THRESHOLD
confident_outputs_test_labels = test_correctness_labels[confident_outputs_mask]
outputs_coverage = np.mean(confident_outputs_mask)

# If no sample clears the threshold there is nothing to average; report NaN
# instead of dividing by zero.
if len(confident_outputs_test_labels) > 0:
    confident_accuracy = np.sum(confident_outputs_test_labels) / len(confident_outputs_test_labels)
else:
    confident_accuracy = float("nan")
true_test_accuracy = np.sum(test_correctness_labels) / len(test_correctness_labels)

print(f"Coverage of confident NAPs: {outputs_coverage:.3f}")
print(f"Accuracy on confident predictions (how often model gets answer right): {confident_accuracy:.4f}")
print(f"Model's true test accuracy (how often model gets answer right): {true_test_accuracy:.4f}")

# Note from a different configuration (layer size 10, 50 epochs): 0.385 coverage for 99% accuracy.

Coverage of confident NAPs: 0.683
Accuracy on confident predictions (how often model gets answer right): 0.9876
Model's true test accuracy (how often model gets answer right): 0.9533


In [44]:
def extract_output_logits(model, data_loader):
    """Run ``model`` over ``data_loader`` and return all output logits.

    Args:
        model: Network to evaluate. It is put in eval mode and moved to the
            module-level ``device``. Its ``forward`` is called directly; for
            ``SimpleNN`` that is the same fc1 -> ReLU -> fc2 -> ReLU -> fc3
            computation that was previously spelled out here by hand.
        data_loader: Iterable yielding ``(inputs, labels)`` batches. The
            labels are ignored.

    Returns:
        NumPy array of logits, one row per sample, in loader order.
    """
    model.eval()
    model.to(device)
    logit_batches = []

    with torch.no_grad():
        for data, _ in data_loader:
            data = data.to(device)
            # Copy each batch back to host memory: Tensor.numpy() only works
            # on CPU tensors, so concatenating CUDA tensors would fail below.
            logit_batches.append(model(data).cpu())

    return torch.cat(logit_batches).numpy()

def get_test_loader(batch_size=256):
    """Return an unshuffled DataLoader over the MNIST test split.

    NOTE(review): this re-declares ``get_test_loader`` with the same behavior
    as the definition in an earlier cell of this notebook; the later
    definition silently shadows the earlier one.

    Args:
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` over MNIST (``train=False``) stored under ``./data``,
        downloaded if missing, with ToTensor + Normalize preprocessing.
    """
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Extra loader options are only added when CUDA is available.
    loader_kwargs = {}
    if torch.cuda.is_available():
        loader_kwargs['num_workers'] = 4
        loader_kwargs['pin_memory'] = True

    held_out = datasets.MNIST('./data', train=False, download=True, transform=pipeline)
    loader = DataLoader(held_out, batch_size=batch_size, shuffle=False, **loader_kwargs)
    return loader

test_loader = get_test_loader()
test_outputs = extract_output_logits(model, test_loader)

# Baseline confidence rule: a prediction is "confident" when the base model's
# largest output logit exceeds this hand-tuned value. The original line also
# carried a trailing "#5.9", presumably an earlier threshold that was tried.
MAX_LOGIT_THRESHOLD = 6.79

# Row-wise maximum computed in one vectorized call instead of a Python loop.
confident_outputs_mask = test_outputs.max(axis=1) > MAX_LOGIT_THRESHOLD
# test_correctness_labels comes from the earlier supervisor-evaluation cell;
# both loaders iterate the test split unshuffled, so rows line up.
confident_outputs_test_labels = test_correctness_labels[confident_outputs_mask]
outputs_coverage = np.mean(confident_outputs_mask)

# If no sample clears the threshold there is nothing to average; report NaN
# instead of dividing by zero.
if len(confident_outputs_test_labels) > 0:
    baseline_confident_accuracy = np.sum(confident_outputs_test_labels) / len(confident_outputs_test_labels)
else:
    baseline_confident_accuracy = float("nan")
baseline_true_accuracy = np.sum(test_correctness_labels) / len(test_correctness_labels)

print(f"Baseline coverage of confident NAPs: {outputs_coverage:.3f}")
print(f"Baseline Accuracy on confident predictions (how often model gets answer right): {baseline_confident_accuracy:.4f}")
print(f"Baseline Model's true test accuracy (how often model gets answer right): {baseline_true_accuracy:.4f}")

Baseline coverage of confident NAPs: 0.719
Baseline Accuracy on confident predictions (how often model gets answer right): 0.9900
Baseline Model's true test accuracy (how often model gets answer right): 0.9533
