In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
import time
import heapq
import itertools

# --- Configuration ---
# Number of random samples used for training, calibration, and testing.
# Training samples are drawn from the MNIST training split.
# Calibration data is used to build the NAP tree.
# Test data is used for the final evaluation.
# Calibration and test samples are drawn, disjointly, from the MNIST test
# split, so their sum must not exceed that split's size.
NUM_TRAIN_SAMPLES = 75
NUM_CALIBRATION_SAMPLES = 8400
NUM_TEST_SAMPLES = 1540

# Optional seed for reproducible runs. When set, it seeds the global NumPy RNG
# (subset sampling) and PyTorch's RNGs (weight init, DataLoader shuffling).
# None leaves both unseeded, so every run draws different subsets and weights.
RANDOM_SEED = None
if RANDOM_SEED is not None:
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)


# --- Model and Data Functions ---

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SimpleNN(nn.Module):
    """Three-layer MLP classifier over flattened 28x28 inputs.

    ``fc1`` and ``fc2`` are ReLU hidden layers and ``fc3`` produces the
    logits. The layer names are relied on by the activation-extraction
    helpers, which recompute ``fc1``/``fc2`` outputs directly.

    Args:
        hidden_dim: Width of both hidden layers (default 28).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, hidden_dim=28, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        ``x`` is flattened to (-1, 28 * 28). ``reshape`` is used instead of
        ``view`` so that non-contiguous inputs (e.g. transposed tensors) work
        as well; for contiguous inputs it returns a view just like ``view``.
        """
        x = x.reshape(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

def get_subset_loaders(num_train, num_calibration, num_test, batch_size=256):
    """
    Create data loaders for training, calibration, and testing using random
    subsets of the MNIST dataset.

    The training subset is sampled without replacement from the MNIST training
    split. The calibration and test subsets are consecutive, non-overlapping
    slices of one random permutation of the MNIST test split.

    Args:
        num_train: Number of training samples.
        num_calibration: Number of calibration samples (NAP tree construction).
        num_test: Number of final-evaluation samples.
        batch_size: Batch size shared by all three loaders.

    Returns:
        (train_loader, calibration_loader, test_loader). Only the training
        loader shuffles.

    Raises:
        ValueError: if a count is negative or exceeds what the corresponding
            MNIST split provides.
    """
    # Reject negative counts before downloading anything. A negative
    # calibration/test count would otherwise pass the sum check below and be
    # used as a negative slice bound, silently producing wrong subsets.
    for split_name, count in (("train", num_train),
                              ("calibration", num_calibration),
                              ("test", num_test)):
        if count < 0:
            raise ValueError(f"Requested {count} {split_name} samples; counts must be non-negative.")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    loader_kwargs = {'num_workers': 4, 'pin_memory': True} if torch.cuda.is_available() else {}

    # Load full datasets (downloaded into ./data if missing).
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

    # Create random subset for training
    if num_train > len(train_dataset):
        raise ValueError(f"Requested {num_train} train samples, but only {len(train_dataset)} are available.")
    train_indices = np.random.choice(len(train_dataset), num_train, replace=False)
    train_subset = Subset(train_dataset, train_indices)
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, **loader_kwargs)

    # Create disjoint random subsets for calibration and testing from the test set
    if num_calibration + num_test > len(test_dataset):
        raise ValueError(f"The sum of calibration ({num_calibration}) and test ({num_test}) samples "
                         f"cannot exceed the total test samples ({len(test_dataset)}).")

    test_indices = np.arange(len(test_dataset))
    np.random.shuffle(test_indices)

    calibration_indices = test_indices[:num_calibration]
    test_indices_final = test_indices[num_calibration : num_calibration + num_test]

    calibration_subset = Subset(test_dataset, calibration_indices)
    test_subset = Subset(test_dataset, test_indices_final)

    calibration_loader = DataLoader(calibration_subset, batch_size=batch_size, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False, **loader_kwargs)

    print(f"Using {len(train_subset)} samples for training.")
    print(f"Using {len(calibration_subset)} samples for calibration (NAP tree construction).")
    print(f"Using {len(test_subset)} samples for final evaluation.")

    return train_loader, calibration_loader, test_loader


def train_model(model, train_loader, epochs=5, lr=0.001):
    """
    Train ``model`` in place with Adam and cross-entropy loss.

    Args:
        model: Classifier producing logits; moved to the module-level ``device``.
        train_loader: Iterable of (data, target) batches.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate (default 0.001, the previously hard-coded value).
    """
    # Move the model first and only then build the optimizer, as the
    # torch.optim documentation recommends, so the optimizer is constructed
    # over the parameters in their final on-device form.
    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for _ in range(epochs):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()

def extract_activations_and_losses(model, data_loader):
    """
    Collect binary neuron-activation patterns (NAPs) and per-sample losses
    for the samples that ``model`` classifies correctly.

    The hidden layers ``fc1`` and ``fc2`` are recomputed by hand so their
    post-ReLU outputs are available. A neuron counts as active when its output
    is > 0. A sample's NAP is its fc1 pattern followed by its fc2 pattern.
    Misclassified samples are dropped.

    Args:
        model: Network exposing ``fc1``, ``fc2`` and ``fc3`` linear layers;
            it is put in eval mode and moved to the module-level ``device``.
        data_loader: Iterable of (data, target) batches.

    Returns:
        (activations, losses) as NumPy arrays. ``activations`` is float32 with
        0/1 entries, shape (n_correct, fc1.out_features + fc2.out_features).
        ``losses`` has shape (n_correct,) and holds the cross-entropy loss
        (same dtype as the logits). When no sample is correct, both arrays
        have zero rows instead of raising.
    """
    model.eval()
    model.to(device)
    all_activations = []
    all_losses = []
    criterion = nn.CrossEntropyLoss(reduction='none')

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)

            x = data.view(-1, 28 * 28)
            h1 = F.relu(model.fc1(x))
            h2 = F.relu(model.fc2(h1))
            output = model.fc3(h2)

            _, predicted = torch.max(output, 1)
            correct_mask = (predicted == target)

            # Nothing to keep from a batch with no correct predictions.
            if not correct_mask.any():
                continue

            h1_binary = (h1[correct_mask] > 0).float()
            h2_binary = (h2[correct_mask] > 0).float()
            nap = torch.cat([h1_binary, h2_binary], dim=1)

            all_activations.append(nap.cpu())
            all_losses.append(criterion(output[correct_mask], target[correct_mask]).cpu())

    if not all_activations:
        # torch.cat rejects an empty list; return correctly shaped empty arrays
        # so callers can handle "no correct samples" without a crash.
        width = model.fc1.out_features + model.fc2.out_features
        return np.zeros((0, width), dtype=np.float32), np.zeros((0,), dtype=np.float32)

    return torch.cat(all_activations).numpy(), torch.cat(all_losses).numpy()

class NapNode:
    """One node of the NAP tree.

    A node holds a subset of sample indices and the neuron constraints that
    define it: ``required`` lists neurons that must be active and
    ``forbidden`` lists neurons that must be inactive. A child starts from
    copies of its parent's two lists and may append one neuron to each.
    """

    # Shared counter giving every node a unique, increasing integer id.
    id_iter = itertools.count()

    def __init__(self, indices, parent=None, required=None, forbidden=None):
        self.id = next(NapNode.id_iter)
        self.indices = indices
        self.parent = parent

        # Copy the parent's constraints so siblings never share a list.
        if parent:
            self.required = list(parent.required)
            self.forbidden = list(parent.forbidden)
        else:
            self.required = []
            self.forbidden = []
        if required is not None:
            self.required.append(required)
        if forbidden is not None:
            self.forbidden.append(forbidden)

        # Bookkeeping filled in by the tree builder.
        self.children = []
        self.is_leaf = True
        self.variance = -1.0  # placeholder until calculate_variance runs
        self.sum_of_activations = None
        self.sample_count = 0

    def calculate_variance(self, losses):
        """Store and return the variance of ``losses`` over this node's indices.

        A node with fewer than two samples gets a variance of 0.0.
        """
        self.variance = np.var(losses[self.indices]) if len(self.indices) >= 2 else 0.0
        return self.variance

def get_all_test_data(model, test_loader, device):
    """
    Run ``model`` over every batch of ``test_loader`` and collect binary
    activation patterns, predictions, labels and per-sample losses.

    Unlike ``extract_activations_and_losses``, all samples are kept (correct
    or not) and the binary patterns are stored as uint8.

    Args:
        model: Network exposing ``fc1``, ``fc2`` and ``fc3`` linear layers;
            it is put in eval mode and moved to ``device``.
        test_loader: Iterable of (data, target) batches.
        device: Device to run inference on.

    Returns:
        (activations, predicted, labels, losses) as NumPy arrays with one row
        per sample seen:
        - activations: uint8 0/1 array of shape
          (n, fc1.out_features + fc2.out_features).
        - predicted: argmax class indices.
        - labels: the targets.
        - losses: per-sample cross-entropy.
        An empty loader yields zero-length arrays instead of raising.
    """
    model.eval()
    model.to(device)
    all_activations, all_predicted, all_labels, all_losses = [], [], [], []
    criterion = nn.CrossEntropyLoss(reduction='none')

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            x = data.view(-1, 28 * 28)
            h1 = F.relu(model.fc1(x))
            h2 = F.relu(model.fc2(h1))
            output = model.fc3(h2)

            _, predicted = torch.max(output, 1)
            losses = criterion(output, target)

            # A neuron is "active" when its post-ReLU output is positive.
            h1_binary = (h1 > 0).byte()
            h2_binary = (h2 > 0).byte()
            activations = torch.cat([h1_binary, h2_binary], dim=1)

            all_activations.append(activations.cpu().numpy())
            all_predicted.append(predicted.cpu().numpy())
            all_labels.append(target.cpu().numpy())
            all_losses.append(losses.cpu().numpy())

    if not all_activations:
        # np.vstack / np.concatenate reject empty lists. Return zero-length
        # arrays instead; the dtypes assume int64 class indices and a float32
        # model, as produced by a non-empty run with default tensors.
        width = model.fc1.out_features + model.fc2.out_features
        return (np.zeros((0, width), dtype=np.uint8),
                np.zeros(0, dtype=np.int64),
                np.zeros(0, dtype=np.int64),
                np.zeros(0, dtype=np.float32))

    return (np.vstack(all_activations),
            np.concatenate(all_predicted),
            np.concatenate(all_labels),
            np.concatenate(all_losses))

def find_covering_leaf_node(activation, root_node):
    """Walk down from ``root_node`` to the node whose NAP covers ``activation``.

    At each internal node, the split neuron is the single constraint the first
    child adds on top of its parent's constraints. If the neuron is active
    (``== 1``), the walk follows the child that requires it; otherwise it
    follows the child that forbids it. The walk also stops at a non-leaf node
    that has no children.

    Args:
        activation: Indexable binary activation pattern of one sample.
        root_node: Node to start the descent from.

    Returns:
        The node where the descent stopped.
    """
    node = root_node
    while not node.is_leaf and node.children:
        first = node.children[0]
        inherited = set(node.required) | set(node.forbidden)
        added = (set(first.required) | set(first.forbidden)) - inherited
        neuron = added.pop()

        if activation[neuron] == 1:
            take_first = neuron in first.required
        else:
            take_first = neuron in first.forbidden
        node = first if take_first else node.children[1]
    return node

def nap_node_to_string(node):
    """Render a node's constraints as sorted, comma-separated neuron lists.

    A falsy ``node`` (e.g. ``None``) is rendered as "Inactive NAP".
    """
    if not node:
        return "Inactive NAP"

    def _joined(neurons):
        return ",".join(str(n) for n in sorted(neurons))

    return f"Required: [{_joined(node.required)}] | Forbidden: [{_joined(node.forbidden)}]"


# --- Main Execution ---

# 1. Load Data: random subsets for training, calibration and testing
#    (calibration and test are disjoint).
train_loader, calibration_loader, test_loader = get_subset_loaders(
    num_train=NUM_TRAIN_SAMPLES,
    num_calibration=NUM_CALIBRATION_SAMPLES,
    num_test=NUM_TEST_SAMPLES,
)

# 2. Train Model on the small training subset.
model = SimpleNN()
print("\nTraining model...")
train_model(model, train_loader)
print("Training complete.")

# 3. Build NAP Tree from the correctly classified calibration samples.
print("\nBuilding NAP tree using calibration data...")
correct_activations, correct_losses = extract_activations_and_losses(model, calibration_loader)

# Upper bound on the number of priority-queue pops in the growth loop.
estimation_steps = 100000
start_time = time.time()

# The root covers every extracted sample.
root_node = NapNode(indices=np.arange(len(correct_activations)))
root_node.sample_count = len(correct_activations)
root_node.sum_of_activations = correct_activations.sum(axis=0)
root_node.calculate_variance(correct_losses)

# Max-priority queue on loss variance, stored as a min-heap of
# (-variance, node_id). A one-element list already satisfies the heap invariant.
pq = [(-root_node.variance, root_node.id)]

# Every node created so far, keyed by id.
tree_nodes = {root_node.id: root_node}
leaf_count = 1

# Counters recording why a popped node was not split.
stopped_insufficient_samples = 0
stopped_no_discriminator = 0
stopped_no_median_split = 0

for estimation_step in range(estimation_steps):
    if not pq:
        break

    # Pop the queued node with the largest loss variance (variances are
    # negated because heapq is a min-heap; unique node ids break ties).
    _, node_id_to_split = heapq.heappop(pq)
    parent_node = tree_nodes[node_id_to_split]

    if not parent_node.is_leaf:
        continue

    if len(parent_node.indices) < 2:
        stopped_insufficient_samples += 1
        continue

    parent_indices = parent_node.indices
    losses_in_nap = correct_losses[parent_indices]

    # Split samples into strictly-above-median ("high") and the rest ("low").
    # If either group is empty (e.g. all losses tie), the node cannot be split.
    loss_median = np.median(losses_in_nap)
    high_loss_mask = (losses_in_nap > loss_median)
    low_loss_mask = ~high_loss_mask

    if not np.any(high_loss_mask) or not np.any(low_loss_mask):
        stopped_no_median_split += 1
        continue

    activations_in_nap = correct_activations[parent_indices]

    # Score each neuron by the gap between its firing rate in the high-loss
    # and low-loss groups; neurons already constrained on this path get -1.
    freq_high = np.mean(activations_in_nap[high_loss_mask], axis=0)
    freq_low = np.mean(activations_in_nap[low_loss_mask], axis=0)
    discriminative_scores = np.abs(freq_high - freq_low)

    used_neurons = parent_node.required + parent_node.forbidden
    if used_neurons:
        discriminative_scores[used_neurons] = -1

    best_neuron_idx = np.argmax(discriminative_scores)

    # A best score of -1 means every neuron is already used. A score of 0
    # means no unused neuron separates the two groups at all. Splitting on a
    # zero-score neuron is worse than uninformative: if that neuron has the
    # same value for every sample in the node, one child is empty and the
    # other is a copy of the parent that gets re-queued with the same
    # variance. Require a strictly positive score.
    if discriminative_scores[best_neuron_idx] <= 0:
        stopped_no_discriminator += 1
        continue

    parent_node.is_leaf = False
    leaf_count -= 1

    # child1 receives the samples that agree with the first low-loss sample on
    # the split neuron, and child2 the rest. A positive score implies the
    # neuron takes both values within the node, so neither child is empty.
    guide_activation_value = activations_in_nap[low_loss_mask][0, best_neuron_idx]

    split_column = activations_in_nap[:, best_neuron_idx]
    mask1 = (split_column == guide_activation_value)
    mask2 = ~mask1

    if guide_activation_value == 1:
        child1 = NapNode(indices=parent_indices[mask1], parent=parent_node, required=best_neuron_idx)
        child2 = NapNode(indices=parent_indices[mask2], parent=parent_node, forbidden=best_neuron_idx)
    else:
        child1 = NapNode(indices=parent_indices[mask1], parent=parent_node, forbidden=best_neuron_idx)
        child2 = NapNode(indices=parent_indices[mask2], parent=parent_node, required=best_neuron_idx)

    parent_node.children = [child1, child2]

    # Sum the activations of the smaller child directly; the larger child's
    # sum and count follow by subtraction from the parent's.
    if len(child1.indices) <= len(child2.indices):
        small_child, large_child = child1, child2
    else:
        small_child, large_child = child2, child1

    small_child.sum_of_activations = np.sum(correct_activations[small_child.indices], axis=0)
    small_child.sample_count = len(small_child.indices)

    large_child.sum_of_activations = parent_node.sum_of_activations - small_child.sum_of_activations
    large_child.sample_count = parent_node.sample_count - small_child.sample_count

    for child in parent_node.children:
        tree_nodes[child.id] = child
        # Only children with >= 2 samples get a variance, are queued for
        # further splitting and are counted in leaf_count; smaller children
        # stay leaves with variance -1.0 and are not counted.
        if len(child.indices) >= 2:
            child.calculate_variance(correct_losses)
            heapq.heappush(pq, (-child.variance, child.id))
            leaf_count += 1