## Interval Analysis

In [None]:
# !pip install tensorboardX

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import time
import matplotlib.pyplot as plt

from torchvision import datasets, transforms
# from tensorboardX import SummaryWriter

# Runtime configuration: everything runs on the CPU unless use_cuda is enabled.
use_cuda = False
device = torch.device("cuda" if use_cuda else "cpu")
batch_size = 64

# Seed NumPy and PyTorch so that runs are repeatable.
np.random.seed(42)
torch.manual_seed(42)


## Dataloaders
# Both splits use the same ToTensor-only pipeline; input normalization is
# not part of the dataset transform.
to_tensor = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(
    'mnist_data/', train=True, download=True, transform=to_tensor
)
test_dataset = datasets.MNIST(
    'mnist_data/', train=False, download=True, transform=to_tensor
)

# Only the training split is shuffled; the test split keeps its order.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False
)

## Simple NN. You can change this if you want. If you change it, mention the architectural details in your report.
class Net(nn.Module):
    """Two-layer fully connected classifier: 784 -> 200 (ReLU) -> 10.

    ``forward`` returns raw, unnormalized class scores (logits). The training
    loop in this file uses ``nn.CrossEntropyLoss``, which applies log-softmax
    to its input itself; feeding it already-softmaxed probabilities would
    apply softmax twice and flatten the loss and its gradients.

    The predicted class (argmax) is the same for logits and for their
    softmax. If probabilities are needed, apply ``F.softmax(logits, dim=-1)``
    to the output.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(28 * 28, 200)
        self.fc2 = nn.Linear(200, 10)

    def forward(self, x):
        """Return class logits of shape ``(N, 10)`` for a batch of inputs.

        ``x`` is flattened to ``(-1, 784)``, so any input whose element count
        is a multiple of 784 is accepted.
        """
        x = x.view((-1, 28 * 28))
        x = F.relu(self.fc(x))
        # Deliberately no softmax: the loss function normalizes the logits.
        return self.fc2(x)

class Normalize(nn.Module):
    """Parameter-free layer that standardizes its input.

    Subtracts 0.1307 and divides by 0.3081 element-wise (presumably the
    MNIST pixel mean and standard deviation -- confirm).
    """

    def forward(self, x):
        """Return ``(x - 0.1307) / 0.3081`` with the same shape as ``x``."""
        centered = x - 0.1307
        return centered / 0.3081

# Prepend the data normalization as the first "layer" of the network.
# This lets us search for adversarial examples around the real image,
# rather than around the normalized image.
model = nn.Sequential(Normalize(), Net()).to(device)

# Put the freshly built model in training mode.
model.train()


In [None]:
def train_model(model, num_epochs):
    """Train ``model`` in place on the module-level ``train_loader``.

    Uses cross-entropy loss and SGD (lr=0.001, momentum=0.9). Batches are
    moved to the module-level ``device``. After every epoch the mean batch
    loss is printed.

    Args:
        model: The network to optimize; its parameters are updated in place.
        num_epochs: Number of full passes over ``train_loader``.
    """
    loss_fn = nn.CrossEntropyLoss()
    sgd = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    for epoch_idx in range(num_epochs):
        model.train()
        epoch_loss = 0.0

        for batch_images, batch_labels in train_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            # Standard step: reset gradients, forward, backward, update.
            sgd.zero_grad()
            batch_loss = loss_fn(model(batch_images), batch_labels)
            batch_loss.backward()
            sgd.step()

            epoch_loss += batch_loss.item()

        print(f'Epoch {epoch_idx+1}/{num_epochs}, Loss: {epoch_loss/len(train_loader):.3f}')

def test_model(model):
    """Print the accuracy (in percent) of ``model`` on ``test_loader``.

    Switches the model to eval mode and runs without gradient tracking.
    Batches are moved to the module-level ``device``.

    Args:
        model: The network to evaluate.
    """
    model.eval()
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            scores = model(batch_images)
            # The index of the largest score along dim 1 is the prediction.
            predicted = torch.max(scores, 1).indices

            num_seen += batch_labels.size(0)
            num_correct += (predicted == batch_labels).sum().item()

    print(f'Accuracy on images: {100 * num_correct / num_seen}')

In [None]:
# Train the model for 10 epochs, then print its accuracy on test_loader.
train_model(model, 10)
test_model(model)

### Write the interval analysis for the simple model

In [None]:
## TODO: Write the interval analysis for the simple model
## you can use https://github.com/Zinoex/bound_propagation

import math


def normalize_bounds(l, u):
    """Apply the input standardization to an interval's lower/upper bounds.

    Both bounds go through ``(v - 0.1307) / 0.3081``. The scale is positive,
    so the lower bound stays the lower bound.

    Args:
        l: Lower bound(s) in the un-normalized input space.
        u: Upper bound(s) in the un-normalized input space.

    Returns:
        Tuple ``(l_normalized, u_normalized)``.
    """
    mean, std = 0.1307, 0.3081
    lower = (l - mean) / std
    upper = (u - mean) / std
    return lower, upper



def ibp_linear(l, u, W, b):
    """Propagate an interval ``[l, u]`` through the affine map ``x @ W.T + b``.

    Positive weights carry lower bounds to lower bounds, while negative
    weights swap the roles of the two bounds. ``W`` is therefore split by
    sign and each part is paired with the matching input bound.

    Args:
        l: Lower bounds; last dimension matches ``W``'s second dimension.
        u: Upper bounds, same shape as ``l``.
        W: Weight matrix, indexed ``[out_features, in_features]``.
        b: Bias added to both output bounds.

    Returns:
        Tuple ``(l_out, u_out)`` of output bounds.
    """
    pos_t = W.clamp(min=0).T
    neg_t = W.clamp(max=0).T

    lower = torch.matmul(l, pos_t) + torch.matmul(u, neg_t) + b
    upper = torch.matmul(u, pos_t) + torch.matmul(l, neg_t) + b
    return lower, upper



def ibp_relu(l, u):
    """Propagate an interval ``[l, u]`` through ReLU.

    ReLU is monotone, so it is applied to each bound independently.

    Returns:
        Tuple ``(l_out, u_out)`` with negative entries clamped to zero.
    """
    lower = l.clamp(min=0)
    upper = u.clamp(min=0)
    return lower, upper



def logits_bounds_for_batch(model, images, eps):
    """Compute interval bounds on the final linear layer's outputs.

    Every image is replaced by the box ``[image - eps, image + eps]``
    clamped to ``[0, 1]``. The box is normalized, flattened per sample and
    pushed through linear -> ReLU -> linear with interval arithmetic.

    Args:
        model: Indexable container whose element ``1`` exposes the linear
            layers ``fc`` and ``fc2`` (element ``0`` is not called here; its
            effect is reproduced by ``normalize_bounds``).
        images: Batch of inputs; the first dimension is the batch size.
        eps: Radius of the perturbation applied to every input element.

    Returns:
        Tuple ``(lower, upper)`` of bounds on the output of ``fc2``, one row
        per sample.
    """
    batch = images.shape[0]

    # Perturbation box, intersected with the [0, 1] input range.
    lower = (images - eps).clamp(0.0, 1.0)
    upper = (images + eps).clamp(0.0, 1.0)

    # Same preprocessing the network applies, then one flat row per sample.
    lower, upper = normalize_bounds(lower, upper)
    lower = lower.view(batch, -1)
    upper = upper.view(batch, -1)

    classifier = model[1]
    hidden_lo, hidden_hi = ibp_linear(
        lower, upper, classifier.fc.weight, classifier.fc.bias
    )
    hidden_lo, hidden_hi = ibp_relu(hidden_lo, hidden_hi)
    return ibp_linear(
        hidden_lo, hidden_hi, classifier.fc2.weight, classifier.fc2.bias
    )


def robust_verified_mask_from_bounds(l_logits, u_logits, labels):
    """Flag the samples whose label provably beats every other class.

    A sample is verified when the lower bound of its labelled class is
    strictly greater than the largest upper bound among all other classes.

    Args:
        l_logits: 2-D tensor of per-class lower bounds, one row per sample.
        u_logits: 2-D tensor of per-class upper bounds (left unmodified).
        labels: Integer class index for every row.

    Returns:
        Boolean tensor with one entry per sample.
    """
    num_rows, _ = l_logits.shape
    row_ids = torch.arange(num_rows, device=labels.device)

    true_lower = torch.gather(l_logits, 1, labels.view(-1, 1)).squeeze(1)

    # Hide the labelled class so the max runs over competing classes only.
    rival_upper = u_logits.clone()
    rival_upper[row_ids, labels] = float("-inf")
    worst_rival = rival_upper.max(dim=1).values

    return (true_lower - worst_rival) > 0


def evaluate_verified_accuracy(
    model, loader, eps
):
    """Measure clean and interval-verified accuracy of ``model`` on ``loader``.

    For every batch the model prediction is compared with the label, and
    interval bounds for perturbation radius ``eps`` are used to decide
    whether the label is provably the top class.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        loader: Iterable of ``(images, labels)`` batches.
        eps: Perturbation radius passed to ``logits_bounds_for_batch``.

    Returns:
        Tuple ``(natural_acc, verified_acc_over_all, verified_acc)`` of
        fractions in ``[0, 1]``:
            natural_acc: samples whose prediction equals the label.
            verified_acc_over_all: samples whose bounds verify the label,
                regardless of the prediction.
            verified_acc: samples that are both verified and predicted
                correctly.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    total = 0
    correct = 0
    verified_all = 0
    verified_and_correct = 0

    # Pure evaluation: skip autograd bookkeeping for the forward pass and
    # the bound propagation (values are unchanged, memory use is lower).
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            B = images.size(0)

            outputs = model(images)
            preds = outputs.argmax(dim=1)
            is_correct = preds.eq(labels)

            l_logits, u_logits = logits_bounds_for_batch(model, images, eps)
            verified_mask = robust_verified_mask_from_bounds(
                l_logits, u_logits, labels
            )

            total += B
            correct += is_correct.sum().item()
            verified_all += verified_mask.sum().item()
            verified_and_correct += (verified_mask & is_correct).sum().item()

    natural_acc = correct / total
    verified_acc_over_all = verified_all / total
    verified_acc = verified_and_correct / total
    return natural_acc, verified_acc_over_all, verified_acc


def sweep_epsilons_and_report(model, loader, eps_values):
    """Evaluate verified accuracy for several radii, print and plot them.

    For each epsilon one line with both verified-accuracy figures is
    printed. Afterwards the clean accuracy (taken from the first epsilon's
    evaluation) is printed and verified accuracy is plotted against epsilon.

    Args:
        model: Network passed on to ``evaluate_verified_accuracy``.
        loader: Data loader passed on to ``evaluate_verified_accuracy``.
        eps_values: Iterable of perturbation radii.

    Returns:
        List with one dict per epsilon, holding the keys ``"epsilon"``,
        ``"verified_over_all"`` and ``"verified_acc"`` (fractions, not
        percent). An empty list is returned, with nothing printed or
        plotted, when ``eps_values`` is empty.
    """
    rows = []
    nat_acc = None
    for eps in eps_values:
        na, va_all, vra = evaluate_verified_accuracy(model, loader, eps)
        if nat_acc is None:
            # Clean accuracy does not depend on eps; keep the first value.
            nat_acc = na
        rows.append(
            {
                "epsilon": float(eps),
                "verified_over_all": va_all,
                "verified_acc": vra,
            }
        )
        print(
            f"eps={eps:.3f} | verified_over_all={va_all*100:.2f}% | verified_acc={vra*100:.2f}%"
        )

    if not rows:
        # Nothing was evaluated, so there is no accuracy to report or plot.
        return rows

    # Leading newline separates the summary from the per-epsilon lines.
    print(f"\nClean accuracy: {nat_acc*100:.2f}%")

    eps_t = torch.tensor([r["epsilon"] for r in rows])
    vra_t = torch.tensor([r["verified_acc"] for r in rows]) * 100
    plt.figure(figsize=(6, 4))
    plt.plot(eps_t, vra_t, "o-", label="Verified accuracy")
    plt.xlabel("epsilon (L-infinity)")
    plt.ylabel("Verified accuracy (%)")
    plt.title("MNIST Verified Robustness /w IBP")
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.show()
    return rows


# Sweep ten evenly spaced perturbation radii from 0.01 to 0.10 on test_loader;
# `results` holds one dict per radius.
epsilon_values = torch.linspace(0.01, 0.10, steps=10).tolist()
results = sweep_epsilons_and_report(model, test_loader, epsilon_values)