In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch
import torch.nn.functional as F
import numpy as np


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 28)
        self.fc2 = nn.Linear(28, 28)
        self.fc3 = nn.Linear(28, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def train_model(model, train_loader, epochs=5):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    model.train()
    model.to(device)
    
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            

def get_train_loader(batch_size=256):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    loader_kwargs = {'num_workers': 4, 'pin_memory': True} if torch.cuda.is_available() else {}
    
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs)
    return train_loader

def extract_activations_and_losses(model, data_loader):
    model.eval()
    model.to(device)
    all_activations = []
    all_losses = []

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            
            x = data.view(-1, 28 * 28)
            h1 = F.relu(model.fc1(x))
            h2 = F.relu(model.fc2(h1))
            output = model.fc3(h2)

            h1_binary = (h1 > 0).float()
            h2_binary = (h2 > 0).float()

            losses = F.cross_entropy(output, target, reduction='none')
            activation_pattern = torch.cat([h1_binary, h2_binary], dim=1)
            all_activations.append(activation_pattern.cpu())
            all_losses.append(losses)
    
    return torch.cat(all_activations).numpy(), torch.cat(all_losses).numpy()

model = SimpleNN()
train_loader = get_train_loader()
train_model(model, train_loader)
activations, losses = extract_activations_and_losses(model, train_loader)

class Node:
    def __init__(self, data=None, split_criteria=None, left=None, right=None):
        self.data = data
        self.split_criteria = split_criteria
        self.left = left
        self.right = right

def find_best_split(objects):
    n_objects = objects.shape[0]
    if n_objects < 2:
        return None
    ones = np.sum(objects, axis=0)
    zeros = n_objects - ones
    entropy_reduction = np.minimum(zeros, ones)
    max_entropy_reduction = np.max(entropy_reduction)
    if max_entropy_reduction == 0:
        return None
    best_pos = np.argmax(entropy_reduction)
    
    return best_pos

def build_tree(objects, current_criteria=None):
    print(f"Current leaf has {len(objects)} objects")
    if current_criteria is None:
        current_criteria = []
    if objects.shape[0] <= 1:
        return Node(data=objects, split_criteria=current_criteria[:])
    split_pos = find_best_split(objects)
    if split_pos is None:
        return Node(data=objects, split_criteria=current_criteria[:])
    left_mask = objects[:, split_pos] == 0
    right_mask = ~left_mask
    left_objects = objects[left_mask]
    right_objects = objects[right_mask]
    if left_objects.shape[0] == 0 or right_objects.shape[0] == 0:
        return Node(data=objects, split_criteria=current_criteria[:])
    # The right branch's criteria includes the new split rule
    right_criteria = current_criteria + [[split_pos, 1]]
    return Node(
        split_criteria=current_criteria[:],
        left=build_tree(left_objects, current_criteria),
        right=build_tree(right_objects, right_criteria)
    )

objects_list = activations
objects_np = np.array(objects_list)
tree = build_tree(objects_np)

Current leaf has 60000 objects
Current leaf has 29835 objects
Current leaf has 14575 objects
Current leaf has 7337 objects
Current leaf has 3451 objects
Current leaf has 1707 objects
Current leaf has 896 objects
Current leaf has 372 objects
Current leaf has 175 objects
Current leaf has 107 objects
Current leaf has 80 objects
Current leaf has 69 objects
Current leaf has 65 objects
Current leaf has 1 objects
Current leaf has 64 objects
Current leaf has 63 objects
Current leaf has 1 objects
Current leaf has 4 objects
Current leaf has 11 objects
Current leaf has 7 objects
Current leaf has 1 objects
Current leaf has 6 objects
Current leaf has 4 objects
Current leaf has 1 objects
Current leaf has 3 objects
Current leaf has 27 objects
Current leaf has 14 objects
Current leaf has 12 objects
Current leaf has 10 objects
Current leaf has 2 objects
Current leaf has 2 objects
Current leaf has 13 objects
Current leaf has 3 objects
Current leaf has 2 objects
Current leaf has 1 objects
Current leaf ha

In [2]:
import numpy as np
def extract_activations_and_correctness(model, data_loader):
    model.eval()
    model.to(device)
    all_activations = []
    all_correctness = []

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            
            x = data.view(-1, 28 * 28)
            h1 = F.relu(model.fc1(x))
            h2 = F.relu(model.fc2(h1))
            output = model.fc3(h2)

            h1_binary = (h1 > 0).float()
            h2_binary = (h2 > 0).float()


            predictions = torch.argmax(output, dim=1)
            is_correct = (predictions == target).float()
            activation_pattern = torch.cat([h1_binary, h2_binary], dim=1)
            all_activations.append(activation_pattern.cpu())
            all_correctness.append(is_correct.cpu())
    
    return torch.cat(all_activations).numpy(), torch.cat(all_correctness).numpy()

def get_test_loader(batch_size=256):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    loader_kwargs = {'num_workers': 4, 'pin_memory': True} if torch.cuda.is_available() else {}
    
    test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, **loader_kwargs)
    return test_loader

def find_leaf(node, datapoint):
    if node.data is not None:
        return node
    split_pos, split_value = node.right.split_criteria[-1]
    
    if datapoint[split_pos] == split_value:
        return find_leaf(node.right, datapoint)
    else:
        return find_leaf(node.left, datapoint)
    
def find_leaf_loss(validation_leaf_prediction, test_feature_vectors, test_losses):
    leaf_data = validation_leaf_prediction.data
    index = [np.where((test_feature_vectors == row).all(axis=1))[0][0] for row in leaf_data][0]
    print(test_losses[index])
    return test_losses[index]

test_loader = get_test_loader()
test_activations, test_correctness_labels = extract_activations_and_correctness(model, test_loader)
test_activations, test_losses = extract_activations_and_losses(model, test_loader)

validation_leaf_predictions = [find_leaf(tree, validation_feature_vector) for validation_feature_vector in test_activations]
validation_loss_predictions = [find_leaf_loss(validation_leaf_prediction, activations, losses) for validation_leaf_prediction in validation_leaf_predictions]
low_loss_threshold = 0.0735
confident_mask = validation_loss_predictions < low_loss_threshold
confident_correctness_labels = test_correctness_labels[confident_mask]
confident_accuracy = np.mean(confident_correctness_labels)

print(f"Coverage of confident predictions: {np.mean(confident_mask):.3f} ({np.sum(confident_mask)}/{len(test_activations)} validation samples)")
print(f"Accuracy on confident predictions: {np.mean(confident_correctness_labels):.4f}")
print(f"Model's true validation accuracy: {np.sum(test_correctness_labels) / len(test_correctness_labels):.4f}")


0.009247447
0.0003644756
0.01302932
0.00017045476
0.013808672
0.05258625
0.006762002
0.4050602
0.1792875
0.09549919
0.00085484196
0.04712821
0.005974885
0.00023040501
0.006190767
0.028020138
0.07237737
0.003757442
0.87319255
0.020754894
0.16522552
0.0076214965
0.037884526
0.0010434904
0.009923411
1.4543428e-05
0.43935034
0.00854467
3.9695904e-05
0.003711599
9.560128e-05
0.6709469
0.008410984
0.5392699
0.00048851955
0.019026162
0.84955233
0.006678522
0.12564288
0.10685033
0.010341983
0.008087388
0.0010071688
0.0011636398
0.43063718
0.027748976
0.17748372
0.22975929
0.0019556223
0.012284569
0.27472064
0.002466971
0.053726688
0.0025030018
1.7595943
0.008711312
0.008789541
0.00093475985
0.0026041903
0.0056960178
0.011195932
0.062836155
0.17311484
0.0011890016
0.0019481267
0.0029240968
0.0056382907
0.006990973
0.0019539567
0.00056870497
0.00919654
0.00028856404
0.017988749
0.511429
0.0031445601
0.003670148
0.031566534
0.080497
0.137602
0.0033389553
0.04331549
0.0032349895
7.4503034e-05
0.02

TypeError: '<' not supported between instances of 'list' and 'float'

In [13]:
low_loss_threshold = 0.00002
validation_loss_predictions = np.array(validation_loss_predictions)
confident_mask = validation_loss_predictions < low_loss_threshold
confident_correctness_labels = test_correctness_labels[confident_mask]
confident_accuracy = np.mean(confident_correctness_labels)

print(f"Coverage of confident predictions: {np.mean(confident_mask):.3f} ({np.sum(confident_mask)}/{len(test_activations)} validation samples)")
print(f"Accuracy on confident predictions: {np.mean(confident_correctness_labels):.4f}")
print(f"Model's true validation accuracy: {np.sum(test_correctness_labels) / len(test_correctness_labels):.4f}")

Coverage of confident predictions: 0.009 (91/10000 validation samples)
Accuracy on confident predictions: 1.0000
Model's true validation accuracy: 0.9537
