# Notebook: Different Attacks Induce Diverse Energy Landscapes

This notebook corresponds to the section "Different Attacks Induce Diverse Energy Landscapes" of the paper *Shedding More Light on Robust Classifiers under the Lens of Energy-based Model*.




## Import necessary libraries and initialize settings


In [None]:
import torch, gc
import torch
import numpy as np
import torch.nn.functional as F
import torchvision.datasets as datasets
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from robustbench.utils import load_model
from torch import nn
import random
from robustbench.utils import load_model
import torchattacks
from matplotlib import pyplot as plt



# Free memory left over from previous runs and seed every RNG used below.
gc.collect()
torch.cuda.empty_cache()
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

# Toggle Weights & Biases logging for the whole notebook.
wandb_log = True

config = {
    # attack details
    "dataset": "CIFAR-10",
    "attack_type": "PGD",
    "epsilon": 8/255,
    # NOTE: compute_attack only applies targeted mode to the CW attack.
    "targeted": True,
    "alpha": 2/255,
    "gamma": 0,
    "steps": 2,
    # adversarial model details
    "model_name": 'Standard',
    "from_RobustBench": True,
    "dataset_trained_on": 'CIFAR-10',
    # target selection for a targeted CW attack: 'random' or 'least'
    "target_label": 'random',
    "kappa": 0,
    "batch_size": 256,
}

if wandb_log:
    import wandb
    wandb.init(
        project="<project_name>",  # set the wandb project where this run will be logged
        save_code=True,
        config=config,
        entity="robustgen",
    )
    # Read the settings back from wandb.config (presumably so that wandb sweeps
    # can override them) instead of from the local dict.
    run_config = wandb.config
else:
    run_config = config

# Single place that turns the configuration into notebook-level variables,
# so both branches above can never drift apart.
dataset = run_config['dataset']
attack_type = run_config['attack_type']
targeted = run_config['targeted']
epsilon = run_config['epsilon']
alpha = run_config['alpha']
factor = run_config['gamma']  # stored under the key 'gamma'
steps = run_config['steps']
model_name = run_config['model_name']
from_RobustBench = run_config['from_RobustBench']
dataset_trained_on = run_config['dataset_trained_on']
target_label = run_config['target_label']
kappa = run_config['kappa']
batch_size = run_config['batch_size']

if wandb_log:
    current_run_name = wandb.run.name
    print(current_run_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')



## Utility functions


In [None]:
def test_model(model, test_dataloader):
    """Evaluate ``model`` on clean (unattacked) data and report top-1 accuracy.

    Args:
        model: classifier whose output is assumed to be logits of shape
            (batch, num_classes) — TODO confirm for non-RobustBench models.
        test_dataloader: iterable yielding ``(images, labels)`` batches.

    Returns:
        float: percentage of samples whose arg-max prediction equals the label.
    """
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for batch, targets in test_dataloader:
            batch, targets = batch.to(device), targets.to(device)
            preds = model(batch).argmax(dim=1)
            n_seen += targets.size(0)
            n_correct += (preds == targets).sum().item()

    accuracy = 100 * n_correct / n_seen
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

def measure_accuracy(model, test_dataloader, epsilon, alpha, steps):
    """Report top-1 accuracy of ``model`` on PGD-perturbed test data.

    Args:
        model: classifier to attack and evaluate.
        test_dataloader: iterable yielding ``(images, labels)`` batches.
        epsilon (float): maximum perturbation of the PGD attack.
        alpha (float): PGD step size.
        steps (int): number of PGD iterations.

    Returns:
        float: robust accuracy in percent.
    """
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0

    # torchattacks PGD in its default mode (no targeted mode is set here).
    attack = torchattacks.PGD(model, eps=epsilon, alpha=alpha, steps=steps)

    for images, labels in test_dataloader:
        images = images.to(device)
        labels = labels.to(device)

        # The attack itself needs gradients, so it runs outside no_grad.
        perturbed_images = attack(images, labels)

        # The evaluation pass does not: skip building an autograd graph.
        with torch.no_grad():
            predicted = model(perturbed_images).argmax(dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Robust Accuracy: {accuracy:.2f}%')
    return accuracy

def compute_energy(logits):
    """Return the marginal energy E(x) = -log(sum_y exp(logits[:, y])).

    The log-sum-exp is taken over dim 1, so one energy value is produced per
    row of ``logits``.
    """
    return torch.logsumexp(logits, dim=1).neg()

def compute_energy_xy(logits, labels):
    """Return the joint energy E(x, y) = -logits[i, labels[i]] for every row i.

    ``labels`` holds one class index per row of ``logits`` (the ground-truth
    label in this notebook).
    """
    rows = torch.arange(len(logits), device=logits.device)
    return logits[rows, labels].neg()

def compute_energy_vector(model, test_dataloader):
    """Compute the energies of every clean (unattacked) test sample.

    Args:
        model: classifier returning logits.
        test_dataloader: iterable yielding ``(images, labels)`` batches.

    Returns:
        tuple[np.ndarray, np.ndarray]: the marginal energies E(x) and the joint
        energies E(x, y) w.r.t. the dataloader labels, one entry per sample.
    """
    model.eval()
    marginal_chunks, joint_chunks = [], []

    with torch.no_grad():
        for batch, targets in test_dataloader:
            batch, targets = batch.to(device), targets.to(device)
            logits = model(batch)
            marginal_chunks.append(compute_energy(logits).cpu().numpy())
            joint_chunks.append(compute_energy_xy(logits, targets).cpu().numpy())

    marginal = np.concatenate(marginal_chunks) if marginal_chunks else np.array([])
    joint = np.concatenate(joint_chunks) if joint_chunks else np.array([])
    return marginal, joint


def _build_attack(model, attack, eps, alpha, steps, targeted):
    """Instantiate the torchattacks object named by ``attack``.

    Args:
        model: classifier to attack.
        attack (str): one of 'PGD', 'TRADES-PGD', 'APGD-T', 'APGD', 'APGD-DLR',
            'FAB', 'Square', 'CW'.
        eps (float): L-inf budget, passed to every attack except CW.
        alpha (float): step size, only passed to PGD and TRADES-PGD.
        steps (int): iteration count, passed to every attack except Square.
        targeted (bool): only honoured by CW; the target is chosen from the
            notebook-level ``target_label`` ('random' or 'least').

    Returns:
        The configured attack object (callable as ``attack(images, labels)``).

    Raises:
        ValueError: for an unknown ``attack`` name, or for a targeted CW attack
            with an unknown ``target_label``.
    """
    if attack == 'PGD':
        return torchattacks.PGD(model, eps=eps, alpha=alpha, steps=steps)
    if attack == 'TRADES-PGD':
        return torchattacks.TPGD(model, eps=eps, alpha=alpha, steps=steps)
    if attack == 'APGD-T':
        return torchattacks.APGDT(model, norm='Linf', eps=eps, steps=steps, n_restarts=1, seed=0, verbose=False, n_classes=10)
    if attack == 'APGD':
        return torchattacks.APGD(model, norm='Linf', eps=eps, steps=steps, n_restarts=1, seed=0, loss='ce', verbose=False)
    if attack == 'APGD-DLR':
        return torchattacks.APGD(model, norm='Linf', eps=eps, steps=steps, n_restarts=1, seed=0, loss='dlr', verbose=False)
    if attack == 'FAB':
        # Previously hard-coded eps=8/255; now honours the caller's budget.
        return torchattacks.FAB(model, norm='Linf', steps=steps, eps=eps, n_restarts=1, verbose=True, multi_targeted=False, n_classes=10)
    if attack == 'Square':
        # Previously hard-coded eps=8/255; now honours the caller's budget.
        return torchattacks.Square(model, norm='Linf', eps=eps, n_queries=1000, n_restarts=1, seed=0, verbose=False, loss='margin')
    if attack == 'CW':
        attacker = torchattacks.CW(model, c=1, kappa=kappa, steps=steps, lr=0.01)
        if targeted:
            print('targeted attack CW')
            if target_label == 'random':
                attacker.set_mode_targeted_random()
            elif target_label == 'least':
                # The original called get_least_likely_label() with no inputs,
                # which raises TypeError; this is the mode-setting API.
                attacker.set_mode_targeted_least_likely()
            else:
                raise ValueError(f"Unknown target_label {target_label!r}; expected 'random' or 'least'.")
        return attacker
    raise ValueError(f"Unknown attack type {attack!r}.")


def compute_attack(model, test_dataloader, eps=8/255, alpha=2/255, steps=10, targeted=False, attack=attack_type):
    """Attack the test set and return the energies of the adversarial examples.

    Args:
        model: PyTorch classifier returning logits.
        test_dataloader: iterable yielding ``(images, labels)`` batches.
        eps (float): maximum L-inf perturbation of the adversaries.
        alpha (float): step size (PGD and TRADES-PGD only).
        steps (int): number of attack iterations.
        targeted (bool): run a targeted attack; only supported for 'CW', any
            other attack stays untargeted and a warning is emitted.
        attack (str): attack name, see ``_build_attack``.

    Returns:
        tuple[np.ndarray, np.ndarray]: marginal energies E(x) and joint energies
        E(x, y) w.r.t. the ground-truth labels, one entry per test sample.

    Raises:
        ValueError: for an unknown attack name or CW ``target_label``.
    """
    import warnings  # local import: only needed for the "targeted ignored" notice

    model.eval()  # Set the model to evaluation mode
    attacker = _build_attack(model, attack, eps, alpha, steps, targeted)
    if targeted and attack != 'CW':
        warnings.warn(f"targeted=True is ignored for attack {attack!r}; only 'CW' runs in targeted mode.")

    energy_vector = []
    energy_vector_xy = []
    correct = 0
    total = 0

    for images, labels in test_dataloader:
        images = images.to(device)
        labels = labels.to(device)

        # Crafting adversarial examples needs gradients, so the attack runs
        # outside no_grad; only the evaluation pass is gradient-free.
        perturbed_images = attacker(images, labels)
        with torch.no_grad():
            logits = model(perturbed_images)

        predicted = logits.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        energy_vector.extend(compute_energy(logits).cpu().numpy())
        energy_vector_xy.extend(compute_energy_xy(logits, labels).cpu().numpy())

    if total > 0:  # avoid ZeroDivisionError on an empty dataloader
        robust_accuracy = 100 * correct / total
        print(f'Robust Accuracy: {robust_accuracy:.2f}%')
        if wandb_log:
            wandb.log({"robust accuracy top1": robust_accuracy})

    return np.array(energy_vector), np.array(energy_vector_xy)


def plot_energy(clean_energy, adversarial_energy, bins=100, name="energy", name_model=model_name+'_'+attack_type):
    """Overlay histograms of clean vs. adversarial energies and save them as a PDF.

    Args:
        clean_energy: per-sample energies of the clean data.
        adversarial_energy: per-sample energies of the adversarial data (same
            length as ``clean_energy``, required by ``np.stack``).
        bins (int): number of bin edges spread over [-50, 20].
        name (str): "energy" labels the x-axis E(x); anything else labels it
            E(x,y). Also used as the file-name suffix.
        name_model (str): file-name prefix; the default is fixed when the
            function is defined.
    """
    paired = np.stack((clean_energy, adversarial_energy), axis=1)
    edges = np.linspace(-50, 20, bins)
    xlabel = 'E(x)' if name == "energy" else 'E(x,y)'

    fig, ax = plt.subplots(figsize=(7, 5))
    # Limits are fixed before plotting so hist() does not autoscale them.
    ax.set_ylim(0, 700)
    ax.set_xlim(-45, 20)
    ax.hist(paired, bins=edges, histtype='bar', color=['lightblue', 'red'],
            label=["clean data", 'adversarial data'], stacked=False)

    ax.set_xlabel(xlabel, fontsize=27)
    ax.set_ylabel('# samples', fontsize=27)
    ax.legend(loc='upper left', fontsize=20)
    fig.set_size_inches(8, 6)
    ax.tick_params(axis='x', labelsize=18)
    ax.tick_params(axis='y', labelsize=18)

    out_file = name_model + '_' + name + '.pdf'
    plt.savefig(out_file, bbox_inches='tight')
    plt.show()



### Load test data and model





In [None]:
# Images are only converted to tensors (ToTensor scales uint8 pixels to [0, 1]);
# no normalization is applied here.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load CIFAR-10 test dataset
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create the dataloaders
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=10, pin_memory=True)

if from_RobustBench:
    # importing from robust bench
    model = load_model(model_name=model_name, dataset='cifar10', threat_model="Linf", model_dir="robustness2/models")

    # DataParallel with device_ids=[0, 1] fails when GPU 1 does not exist, so
    # only wrap the model when at least two GPUs are visible.
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model, device_ids=[0, 1])
    model.to(device)

else:
    # using local model
    model_path = '/home/robustgen/Downloads/PhD/notebooks/models/wandb_model/' + model_name + '.pt'
    # NOTE(review): WideResNet is neither defined nor imported in this notebook;
    # it must be provided before this branch can run.
    model = WideResNet().to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()






### Compute accuracies and energy vectors on clean and adversarial data

In [None]:
# Clean accuracy and clean-data energies.
top1 = test_model(model, test_dataloader)
if wandb_log:
    wandb.log({"clean accuracy top 1": top1})

v_e, v_exy = compute_energy_vector(model, test_dataloader)

# Forward the configured epsilon/alpha as well: without them compute_attack
# silently falls back to its defaults (8/255, 2/255) whatever the config says.
v_adv_e, v_adv_exy = compute_attack(
    model,
    test_dataloader,
    eps=epsilon,
    alpha=alpha,
    steps=steps,
    targeted=targeted,
    attack=attack_type,
)

Plot the marginal energy, i.e. E(x), for clean and adversarial data

In [None]:
# Marginal energy E(x) of clean vs. adversarial samples ('name' keeps its default "energy").
plot_energy(clean_energy=v_e, adversarial_energy=v_adv_e, bins=200)

Plot the joint energy with respect to the ground-truth label y, i.e. E(x,y), for clean and adversarial data

In [None]:
# Joint energy E(x,y): any name other than "energy" selects the E(x,y) axis label.
# The name is also the PDF file-name suffix, so use a real, filesystem-safe value
# instead of the "<enter_name>" placeholder.
plot_energy(v_exy, v_adv_exy, bins=200, name="energy_xy")

### Save energies if needed (optional)

In [None]:
import os  # not imported at the top of the notebook; needed for makedirs below

out_dir = './data_vector'
os.makedirs(out_dir, exist_ok=True)

# Save the energy vectors as npy files
saved_paths = []
for file_name, vector in (
    ("clean_energy.npy", v_e),
    ("adversarial_energy.npy", v_adv_e),
    ("clean_energy_xy.npy", v_exy),
    ("adversarial_energy_xy.npy", v_adv_exy),
):
    path = f"{out_dir}/{file_name}"
    np.save(path, vector)
    saved_paths.append(path)

if wandb_log:
    # Log the vectors into wandb, then close the run.
    for path in saved_paths:
        wandb.save(path)
    wandb.finish()
