In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
from tqdm import tqdm

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

class SimpleNN(nn.Module):
    """Fully connected classifier for 28x28 inputs.

    Four hidden linear layers of width 1000 (``fc1``..``fc4``), each followed
    by ReLU, and a final linear layer ``fc5`` that produces 10 logits.
    """

    def __init__(self):
        super().__init__()
        hidden_width = 1000
        self.fc1 = nn.Linear(28 * 28, hidden_width)
        self.fc2 = nn.Linear(hidden_width, hidden_width)
        self.fc3 = nn.Linear(hidden_width, hidden_width)
        self.fc4 = nn.Linear(hidden_width, hidden_width)
        self.fc5 = nn.Linear(hidden_width, 10)

        # Placeholder list intended for NAP analysis; nothing in this class fills it.
        self.relu_layers = []

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return the 10 output logits."""
        hidden = x.view(-1, 28 * 28)
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = F.relu(layer(hidden))
        return self.fc5(hidden)

def train_model(model, train_loader, epochs=5, lr=0.001):
    """Train ``model`` in place with Adam and cross-entropy loss.

    The model is moved to the module-level ``device`` and put in training
    mode. For every epoch a tqdm progress bar shows the running mean batch
    loss and the running accuracy in percent.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to class logits.
        train_loader: Iterable yielding ``(data, target)`` batches.
        epochs: Number of full passes over ``train_loader``.
        lr: Adam learning rate. Defaults to 0.001, the value that used to be
            hard-coded, so existing callers behave as before.

    Returns:
        None. Progress is reported through tqdm and a final print.
    """
    # Move the model before building the optimizer, as the torch.optim
    # documentation recommends, so the optimizer is created over the
    # parameters that will actually be trained.
    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        running_loss = 0.0
        correct_preds = 0
        total_preds = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch_idx, (data, target) in enumerate(progress_bar):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # detach() instead of the legacy .data attribute: the predictions
            # are only used for the displayed accuracy, not for gradients.
            predicted = output.detach().argmax(dim=1)
            total_preds += target.size(0)
            correct_preds += (predicted == target).sum().item()

            progress_bar.set_postfix(loss=(running_loss / (batch_idx + 1)),
                                     accuracy=100. * correct_preds / total_preds)
    print("Finished Training")

def get_train_loader(batch_size=256):
    """Build the shuffled MNIST training loader.

    The training split is read from ``./data`` (downloaded when missing),
    converted to tensors and normalized with mean 0.1307 and std 0.3081.

    Args:
        batch_size: Number of samples per batch.

    Returns:
        Tuple ``(train_loader, train_dataset)``.
    """
    # Worker processes and pinned memory are requested only when CUDA is available.
    if torch.cuda.is_available():
        extra_loader_args = {'num_workers': 4, 'pin_memory': True}
    else:
        extra_loader_args = {}

    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    mnist_train = datasets.MNIST('./data', train=True, download=True, transform=preprocessing)
    loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, **extra_loader_args)
    return loader, mnist_train

def eval_model_accuracy(model):
    """Print the model's top-1 accuracy on the MNIST test split.

    Puts ``model`` in eval mode, loads the test split from ``./data``
    (downloading it when missing) with the same normalization as training,
    and evaluates it in batches of 256 on the module-level ``device``.
    The result is printed; nothing is returned.
    """
    model.eval()

    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=preprocessing)
    batches = DataLoader(mnist_test, batch_size=256, shuffle=False)

    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for data, target in batches:
            data = data.to(device)
            target = target.to(device)
            # Index of the largest logit in each row is the predicted class.
            predicted = model(data).argmax(dim=1)
            n_seen += target.size(0)
            n_correct += (predicted == target).sum().item()

    print(f"Test Accuracy: {100 * n_correct / n_seen:.2f}%")

def get_all_activations(model, inputs):
    """Return the post-ReLU activations of all four hidden layers, concatenated.

    Re-runs the hidden part of the forward pass (``fc1``..``fc4``, each
    followed by ReLU) under ``torch.no_grad()`` and joins the per-layer
    activations along the feature dimension, in layer order.

    Args:
        model: Module exposing linear layers ``fc1``, ``fc2``, ``fc3`` and
            ``fc4``. It is switched to eval mode as a side effect.
        inputs: Tensor that can be viewed as ``(N, fc1.in_features)``. It is
            not moved here, so it must already be on the model's device.

    Returns:
        Tensor of shape ``(N, total hidden width)``, i.e. ``(N, 4000)`` for
        four layers of 1000 units.
    """
    model.eval()

    # Read the flattened input width from the first layer rather than
    # hard-coding it; fall back to 28 * 28 if the layer does not expose it.
    in_features = getattr(model.fc1, "in_features", 28 * 28)
    hidden_layers = (model.fc1, model.fc2, model.fc3, model.fc4)

    per_layer = []
    with torch.no_grad():
        hidden = inputs.view(-1, in_features)
        for layer in hidden_layers:
            hidden = F.relu(layer(hidden))
            per_layer.append(hidden)

    # One row per sample, hidden layers side by side.
    return torch.cat(per_layer, dim=1)

def compute_input_nap(activations):
    """Binarize activations into a NAP: 1 where strictly positive, else 0 (int32)."""
    is_active = torch.gt(activations, 0)
    return is_active.to(torch.int32)

def compute_baseline_nap(model, train_dataset, target_class, delta):
    """Compute the baseline NAP specification for one class.

    Every sample of ``target_class`` is passed through the model, its hidden
    activations are binarized, and each neuron is then assigned a state:

    * ``1`` if it is active on at least ``delta`` of the class samples,
    * ``0`` if it is inactive on at least ``delta`` of the class samples,
    * ``2`` (meaning ``*``, unconstrained) otherwise.

    When both conditions hold (possible for ``delta <= 0.5``) state ``1`` wins.

    Args:
        model: Network accepted by ``get_all_activations``; put in eval mode.
        train_dataset: Dataset with a ``targets`` attribute (tensor or
            sequence of labels) aligned with its sample indices.
        target_class: Label whose samples define the specification.
        delta: Frequency threshold in ``[0, 1]``.

    Returns:
        1-D ``torch.int`` tensor on ``device`` with one state per neuron.

    Raises:
        ValueError: If the dataset holds no sample of ``target_class``.
    """
    print(f"\nComputing baseline NAP specification for class {target_class} with delta={delta}...")

    # Select the class's samples with one vectorized comparison instead of a
    # Python-level loop over every label.
    targets = torch.as_tensor(train_dataset.targets)
    class_indices = torch.nonzero(targets == target_class).flatten().tolist()
    if not class_indices:
        # Without this check the failure surfaces later as an opaque torch.cat error.
        raise ValueError(f"No samples with label {target_class} found in the dataset")
    class_subset = Subset(train_dataset, class_indices)
    class_loader = DataLoader(class_subset, batch_size=256, shuffle=False)

    all_class_naps = []
    model.eval()

    with torch.no_grad():
        for inputs, _ in tqdm(class_loader, desc="Extracting Activations"):
            inputs = inputs.to(device)
            activations = get_all_activations(model, inputs)
            binary_naps = compute_input_nap(activations)
            # Accumulate on the CPU so the full class matrix does not sit on the GPU.
            all_class_naps.append(binary_naps.cpu())

    all_class_naps = torch.cat(all_class_naps, dim=0)
    num_samples, num_neurons = all_class_naps.shape

    # Apply statistical abstraction (A_tilde) as per equation (6)
    activation_counts = torch.sum(all_class_naps, dim=0)
    deactivation_counts = num_samples - activation_counts

    # float64 keeps the frequencies identical to Python's int / int division,
    # which matters for thresholds close to 1 such as 0.999.
    activation_freq = activation_counts.double() / num_samples
    deactivation_freq = deactivation_counts.double() / num_samples

    mostly_active = activation_freq >= delta
    # Activation takes precedence when both frequencies reach the threshold.
    mostly_inactive = (deactivation_freq >= delta) & ~mostly_active

    baseline_nap = torch.full((num_neurons,), 2, dtype=torch.int)  # 2 represents '*'
    baseline_nap[mostly_active] = 1  # State '1'
    baseline_nap[mostly_inactive] = 0  # State '0'

    num_binary_neurons = (baseline_nap != 2).sum().item()
    print(f"Baseline NAP computed. Size (number of binary neurons): {num_binary_neurons}/{num_neurons}")
    return baseline_nap.to(device)

# Configuration
BATCH_SIZE = 256  # batch size for the training loader and the final test loader
EPOCHS = 10  # number of passes over the training set
DELTA = 0.999  # Confidence ratio for the statistical NAP function (A_tilde)

# Training
# Build a fresh network, train it on the training split, then print its test accuracy.
print("--- Starting Model Training ---")
model = SimpleNN()
train_loader, train_dataset = get_train_loader(BATCH_SIZE)
train_model(model, train_loader, epochs=EPOCHS)
eval_model_accuracy(model)

# NAP Analysis
# Fit one baseline NAP per digit class; naps[c] holds the specification of class c.
print(f"\n=== Fitting Baseline NAPs for All Classes ===")
naps = []
for target_class in range(10):
    print(f"\n--- Processing Class {target_class} ---")
    class_nap = compute_baseline_nap(model, train_dataset, target_class, DELTA)
    naps.append(class_nap)

# Summary: for each class, how many neurons were fixed to 0 or 1 (state 2 means '*').
print(f"\n=== NAP Analysis Complete ===")
print(f"Total classes processed: {len(naps)}")
for i, nap in enumerate(naps):
    num_binary_neurons = (nap != 2).sum().item()
    total_neurons = nap.shape[0]
    print(f"Class {i}: {num_binary_neurons}/{total_neurons} binary neurons")

# NAP Evaluation Functions
def nap_subsumes(class_nap, input_nap):
    """Return a 0-dim bool tensor: True when ``input_nap`` satisfies ``class_nap``.

    A position is violated when ``class_nap`` fixes a state there (anything
    other than the wildcard ``2``) and ``input_nap`` holds a different value.
    """
    constrained = class_nap != 2
    violated = constrained & (class_nap != input_nap)
    return torch.logical_not(torch.any(violated))

# Test dataset setup
# Same preprocessing as the training loader: tensor conversion plus normalization.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
# Per true class: [correct_total, correct_covered, incorrect_total, incorrect_covered]
results = {i: [0, 0, 0, 0] for i in range(10)}

print("\n=== Running Final NAP Evaluation ===")
model.eval()
with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Final Evaluation"):
        inputs, labels = inputs.to(device), labels.to(device)
        # Binary activation pattern of every sample in the batch.
        activations = get_all_activations(model, inputs)
        input_naps = compute_input_nap(activations)
                  
        # Predicted class = index of the largest logit.
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        for i in range(len(labels)):
            true_label = labels[i].item()
            predicted_label = predicted[i].item()
            # Coverage is checked against the NAP of the TRUE class, not the predicted one.
            class_nap = naps[true_label]
            input_nap = input_naps[i]
            is_covered = nap_subsumes(class_nap, input_nap)
            is_correct = (predicted_label == true_label)
            if is_correct:
                results[true_label][0] += 1
                if is_covered:
                    results[true_label][1] += 1
            else:
                results[true_label][2] += 1
                if is_covered:
                    results[true_label][3] += 1

print("\n--- Evaluation Results ---")
for class_id, data in results.items():
    correct_total, correct_covered, incorrect_total, incorrect_covered = data
         
    # Coverage over correct examples
    coverage_correct = (correct_covered / correct_total) * 100 if correct_total > 0 else 0
    # Coverage over incorrect examples
    coverage_incorrect = (incorrect_covered / incorrect_total) * 100 if incorrect_total > 0 else 0
         
    print(f"\n--- Class {class_id} ---")
    print(f"Coverage over correctly classified examples: {coverage_correct:.2f}% ({correct_covered}/{correct_total})")
    print(f"Coverage over misclassified examples: {coverage_incorrect:.2f}% ({incorrect_covered}/{incorrect_total})")

print("\n--------------------------")

Using device: cpu
--- Starting Model Training ---


Epoch 1/10:  51%|█████     | 119/235 [00:02<00:02, 44.94it/s, accuracy=87.7, loss=0.382]


KeyboardInterrupt: 

In [None]:
# --- Evaluation Results ---

# --- Class 0 ---
# Coverage over correctly classified examples: 78.70% (761/967)
# Coverage over misclassified examples: 0.00% (0/13)

# --- Class 1 ---
# Coverage over correctly classified examples: 87.35% (987/1130)
# Coverage over misclassified examples: 0.00% (0/5)

# --- Class 2 ---
# Coverage over correctly classified examples: 87.52% (884/1010)
# Coverage over misclassified examples: 0.00% (0/22)

# --- Class 3 ---
# Coverage over correctly classified examples: 83.72% (828/989)
# Coverage over misclassified examples: 0.00% (0/21)

# --- Class 4 ---
# Coverage over correctly classified examples: 84.47% (821/972)
# Coverage over misclassified examples: 0.00% (0/10)

# --- Class 5 ---
# Coverage over correctly classified examples: 83.43% (730/875)
# Coverage over misclassified examples: 0.00% (0/17)

# --- Class 6 ---
# Coverage over correctly classified examples: 78.76% (738/937)
# Coverage over misclassified examples: 0.00% (0/21)

# --- Class 7 ---
# Coverage over correctly classified examples: 85.23% (860/1009)
# Coverage over misclassified examples: 0.00% (0/19)

# --- Class 8 ---
# Coverage over correctly classified examples: 87.79% (820/934)
# Coverage over misclassified examples: 2.50% (1/40)

# --- Class 9 ---
# Coverage over correctly classified examples: 86.61% (815/941)
# Coverage over misclassified examples: 25.00% (17/68)

# --------------------------