We start from an ImageNet-pretrained ResNet152, replace its first and last layers, adapt it to take multi-resolution inputs generated from a single image, and finetune it on CIFAR-100.

We then train separate linear layers for each of the intermediate representations on top of a frozen backbone model.

From these, we form a self-ensemble. At the end, we evaluate its adversarial accuracy on CIFAR-100 using the RobustBench AutoAttack (with the `rand` flag enabled).

The whole Colab should take ~60 minutes on an A100 GPU and should be self-contained.

It should already give you adversarial robustness at or above the state of the art on CIFAR-100 under $L_\infty$ attacks with $\epsilon = 8/255$. It also visualizes the successfully attacked images and the class prototypes optimized directly from pixels.

In [None]:
!nvidia-smi
!pip install torch torchvision numpy tqdm matplotlib

In [2]:
from types import SimpleNamespace
import json
import copy
import hashlib
import os
import random
import time
from contextlib import contextmanager
from pathlib import Path
from torchvision.models import resnet152, ResNet152_Weights
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torchvision import datasets, models
from torchvision.models import resnet152
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import random
import hashlib
import time
import copy
import matplotlib.pyplot as plt
from contextlib import contextmanager
from tqdm import tqdm

# Everything below moves tensors to "cuda" unconditionally, so fail early without a GPU.
assert torch.cuda.is_available()

#
# config
#

# All working directories live next to (not inside) the current working directory.
classes_path = Path.cwd().parent / "data"  # class-name JSON files are read from here
dataset_path = Path.cwd().parent / "datasets"  # root passed to the torchvision CIFAR datasets
weights_path = Path.cwd().parent / "weights"

# Create the directories up front; exist_ok makes re-running the cell harmless.
os.makedirs(classes_path, exist_ok=True)
os.makedirs(dataset_path, exist_ok=True)
os.makedirs(weights_path, exist_ok=True)

#
# seed
#

# seed = 41
# random.seed(seed)
# np.random.seed(seed)
# torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = True


@contextmanager
def isolated_environment():
    """Run a code block without leaking RNG or NumPy print-option changes.

    On entry the NumPy, Python `random`, torch CPU and (when CUDA is
    available) torch CUDA generator states are snapshotted together with the
    NumPy print options. On exit, whether the block finished normally or
    raised, every snapshot is written back, so code after the `with` sees the
    same random stream it would have seen had the block never run.
    """
    snapshot = {
        "numpy": np.random.get_state(),
        "python": random.getstate(),
        "torch": torch.get_rng_state(),
        "cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        "printoptions": np.get_printoptions(),
    }
    try:
        yield  # the caller's block runs here
    finally:
        np.random.set_state(snapshot["numpy"])
        random.setstate(snapshot["python"])
        torch.set_rng_state(snapshot["torch"])
        # None (no CUDA) and an empty per-device list are both skipped
        if snapshot["cuda"]:
            torch.cuda.set_rng_state_all(snapshot["cuda"])
        np.set_printoptions(**snapshot["printoptions"])


#
# data
#

# Number of target classes; selects CIFAR-10 (10) or CIFAR-100 (100) below and
# is later used as the output size of the classifier head.
classes = 100

if classes == 10:
    trainset = datasets.CIFAR10(root=dataset_path, train=True, download=True)
    testset = datasets.CIFAR10(root=dataset_path, train=False, download=True)
    original_images_train_np = np.array(trainset.data)
    original_labels_train_np = np.array(trainset.targets)
    original_images_test_np = np.array(testset.data)
    original_labels_test_np = np.array(testset.targets)
    # NOTE(review): this file is read, not downloaded; it must already exist under classes_path
    classes_cifar10 = json.loads((classes_path / "cifar10_classes.json").read_text())
elif classes == 100:
    trainset = datasets.CIFAR100(root=dataset_path, train=True, download=True)
    testset = datasets.CIFAR100(root=dataset_path, train=False, download=True)
    original_images_train_np = np.array(trainset.data)
    original_labels_train_np = np.array(trainset.targets)
    original_images_test_np = np.array(testset.data)
    original_labels_test_np = np.array(testset.targets)
    # NOTE(review): this file is read, not downloaded; it must already exist under classes_path
    classes_cifar100 = json.loads((classes_path / "cifar100_classes.json").read_text())
else:
    # only the two CIFAR variants are supported
    assert False

# Working copies used by training and evaluation; the original_* arrays are kept unscaled.
images_train_np = original_images_train_np / 255.0  # scale to [0, 1]
images_test_np = original_images_test_np / 255.0
labels_train_np = original_labels_train_np
labels_test_np = original_labels_test_np

#
# multi resolution preprocessing (channel layer)
#


def custom_rand(input_tensor, size):
    """Return a float32 tensor of uniform [0, 1) noise with shape `size`.

    The values come from NumPy's global generator (`np.random.rand`), so they
    follow `np.random.seed` rather than the torch seed.

    Args:
        input_tensor: Tensor whose device the noise is placed on. Its values
            are not used. If it is not a tensor, the noise is placed on
            "cuda", which was the previous unconditional behaviour.
        size: Iterable of ints giving the shape of the result.

    Returns:
        A `torch.Tensor` of shape `size` on the chosen device.
    """
    # Follow the input's device instead of always using the default GPU, so
    # CPU tensors and tensors on a non-default GPU get noise they can be added to.
    device = input_tensor.device if torch.is_tensor(input_tensor) else "cuda"
    return torch.Tensor(np.random.rand(*size)).to(device)


def custom_choices(items, tensor):
    """Draw one element of `items` per entry of `tensor`.

    Only the length of `tensor` is used; its values are ignored. Sampling is
    uniform with replacement and uses NumPy's global generator.

    Returns:
        A 1-D NumPy array with `len(tensor)` elements taken from `items`.
    """
    draws_needed = len(tensor)
    return np.random.choice(items, size=draws_needed)


# Side lengths each image is downsampled to before being upsampled again;
# one view per entry is stacked along the channel axis.
resolutions = [32, 16, 8, 4]  # pretty arbitrary
# If True, the order of the views in the channel stack is permuted on every call.
shuffle_image_versions_randomly = False  # to shuffle randomly which image is which in the multi-res stack (false in paper)
# If False, the input is simply repeated len(resolutions) times with no augmentation.
transform = True  # to apply the transformations or not (true in paper)


def default_make_multichannel_input(images):
    """Repeat `images` along dim 1 once per entry of the global `resolutions`.

    No augmentation or resampling is applied; the result just has
    `len(resolutions)` times as many channels as the input.
    """
    repeated = [images for _ in resolutions]
    return torch.cat(repeated, dim=1)


def apply_transformations(images, down_res, up_res, jit_x, jit_y, down_noise, up_noise, contrast, color_amount):
    """Augment a batch, squeeze it through a low-resolution bottleneck and add noise.

    Per sample: adjust contrast, circularly shift in both spatial axes, then
    blend with the channel mean (towards grayscale). Per batch: bicubic resize
    to `down_res`, add noise, bicubic resize to `up_res`, add noise, clip to
    [0, 1].

    Args:
        images: Batch of images; assumes (batch, 3, H, W) with values in
            [0, 1] -- TODO confirm (the added noise has 3 channels).
        down_res: Side length of the intermediate low-resolution image.
        up_res: Side length of the returned image.
        jit_x: Per-sample shift along dim -2, indexable by batch position.
        jit_y: Per-sample shift along dim -1, indexable by batch position.
        down_noise: Multiplier for the noise added at `down_res`.
        up_noise: Multiplier for the noise added at `up_res`.
        contrast: Per-sample contrast factor passed to `adjust_contrast`.
        color_amount: Per-sample weight of the original colors; the remainder
            (1 - color_amount) is the weight of the channel mean.

    Returns:
        Tensor of shape (batch, C, up_res, up_res) clipped to [0, 1].
    """
    # images = torch.mean(images,axis=1,keepdims=True) # for MNIST alone

    images_collected = []
    for i in range(images.shape[0]):
        image = images[i]
        image = torchvision.transforms.functional.adjust_contrast(image, contrast[i])  # changing contrast
        # torch.roll wraps around, so pixels shifted past one edge re-enter on the opposite edge
        image = torch.roll(image, shifts=(jit_x[i], jit_y[i]), dims=(-2, -1))  # shift the result in x and y
        image = color_amount[i] * image + torch.mean(image, axis=0, keepdims=True) * (1 - color_amount[i])  # shifting in the color <-> grayscale axis
        images_collected.append(image)

    images = torch.stack(images_collected, axis=0)

    images = F.interpolate(images, size=(down_res, down_res), mode="bicubic")  # decrease the resolution
    # custom_rand only needs its first argument as a tensor handle; the noise it
    # returns is uniform in [0, 1), i.e. non-negative, before scaling by down_noise
    noise = down_noise * custom_rand(images + 312, (images.shape[0], 3, down_res, down_res)).to("cuda")  # low res noise
    images = images + noise

    images = F.interpolate(images, size=(up_res, up_res), mode="bicubic")  # increase the resolution
    noise = up_noise * custom_rand(images + 812, (images.shape[0], 3, up_res, up_res)).to("cuda")  # high res noise
    images = images + noise

    images = torch.clip(images, 0, 1)  # clipping to the right range of values
    return images


def make_multichannel_input(images):
    """Build the multi-resolution channel stack for a batch of images.

    With the global `transform` enabled, one independently augmented view is
    produced per entry of `resolutions` (random shift, contrast, color amount,
    a down/up-sampling bottleneck and additive noise). Otherwise the input is
    repeated unchanged. The views are concatenated along dim 1, optionally in
    a random order when `shuffle_image_versions_randomly` is set.
    """
    if transform:
        max_shift = 3  # shifts are drawn uniformly from -max_shift..+max_shift inclusive, per axis
        shift_options = range(-max_shift, max_shift + 1)
        level_options = np.linspace(0.5, 1.0, 100)

        views = []
        for idx, res in enumerate(resolutions):
            # The four draws stay in this order so the NumPy random stream is consumed as before.
            shifts_x = custom_choices(shift_options, images + idx)
            shifts_y = custom_choices(shift_options, 51 * images + 7 * idx + 125 * res)
            contrasts = custom_choices(level_options, 5 + 7 * images + 8 * idx + 2 * res)
            color_amounts = custom_choices(level_options, 5 + 7 * images + 8 * idx + 2 * res)
            view = apply_transformations(
                images,
                down_res=res,
                up_res=32,  # CIFAR-10 / CIFAR-100 image size
                jit_x=shifts_x,
                jit_y=shifts_y,
                down_noise=0.2,  # scale of the uniform [0, 1) noise added at low resolution
                up_noise=0.2,  # scale of the uniform [0, 1) noise added at full resolution
                contrast=contrasts,
                color_amount=color_amounts,
            )
            views.append(view)
    else:
        views = [images] * len(resolutions)

    if shuffle_image_versions_randomly:
        order = torch.randperm(len(views))
        views = [views[k] for k in order]
    return torch.concatenate(views, axis=1)

In [None]:
# demo: show the first two test images next to their multi-resolution views
sample_images = images_test_np[:5]

for j in [0, 1]:
    # NOTE(review): the stack is rebuilt on every iteration, so each figure shows a fresh random draw
    # NHWC numpy -> NCHW tensor on the GPU -> multi-res stack -> back to NHWC numpy for plotting
    multichannel_images = make_multichannel_input(torch.Tensor(sample_images.transpose([0, 3, 1, 2])).to("cuda")).detach().cpu().numpy().transpose([0, 2, 3, 1])

    # one panel for the original plus one per 3-channel view in the stack
    N = 1 + multichannel_images.shape[3] // 3

    plt.figure(figsize=(N * 5.5, 5))

    plt.subplot(1, N, 1)
    plt.title("original")
    plt.imshow(sample_images[j])
    plt.xticks([], [])
    plt.yticks([], [])

    for i in range(N - 1):
        plt.subplot(1, N, i + 2)
        plt.title(f"res={resolutions[i]}")
        # channels 3*i .. 3*i+2 hold the i-th view
        plt.imshow(multichannel_images[j, :, :, 3 * i : 3 * (i + 1)])
        plt.xticks([], [])
        plt.yticks([], [])

    plt.show()

In [None]:
def eval_model(model, images_in, labels_in, batch_size=128):
    """Run `model` over `images_in` in batches and count correct predictions.

    Args:
        model: Callable mapping an NCHW float tensor on "cuda" to logits. It
            is used as-is; the caller decides train/eval mode.
        images_in: NumPy array indexed (sample, H, W, C).
        labels_in: NumPy array of integer labels, one per sample.
        batch_size: Number of samples per forward pass.

    Returns:
        Tuple `(n_correct, n_evaluated, logits)` where `logits` is the NumPy
        array of all outputs concatenated along axis 0.
    """
    n_samples = len(images_in)
    n_batches = int(np.ceil(float(n_samples) / float(batch_size)))

    logit_chunks = []
    pred_chunks = []

    with torch.no_grad():
        for step in tqdm(range(n_batches), desc="Eval", ncols=100):
            start = step * batch_size
            stop = min(start + batch_size, n_samples)

            # NHWC numpy slice -> NCHW float tensor on the GPU
            batch = torch.Tensor(images_in[start:stop].transpose([0, 3, 1, 2])).to("cuda")
            logits = model(batch)

            logit_chunks.append(logits.detach().cpu().numpy())
            pred_chunks.append(torch.argmax(logits, axis=-1).detach().cpu().numpy())

    predictions = np.concatenate(pred_chunks, axis=0)
    logits_all = np.concatenate(logit_chunks, axis=0)
    n_correct = np.sum(predictions == labels_in)

    return n_correct, predictions.shape[0], logits_all


from torchvision.models import resnet152, ResNet152_Weights

# Backbone: torchvision ResNet152 initialised from the IMAGENET1K_V2 weights.
imported_model = resnet152(weights=ResNet152_Weights.IMAGENET1K_V2)

# fixed for ResNet152 first conv layer, change for others by hand
in_planes = 3
planes = 64
stride = 2
N = len(resolutions)  # input channels multiplier due to multi-res input

# Same geometry as the layer it replaces but with N * 3 input channels; freshly initialised.
conv2 = nn.Conv2d(N * in_planes, planes, kernel_size=7, stride=stride, padding=3, bias=False)

# replacing the pretrained conv with new init of the multi-res one
imported_model.conv1 = copy.deepcopy(conv2)

# getting the final layer to predict the right number of classes -> new init
imported_model.fc = nn.Linear(2048, classes)


class ImportedModelWrapper(nn.Module):
    """Wrap a backbone so multi-resolution preprocessing runs inside `forward`.

    Attributes set in `__init__`:
        imported_model: The backbone network that receives the preprocessed stack.
        multichannel_fn: Callable turning a batch of images into the
            channel-stacked multi-resolution input.
    """

    def __init__(self, imported_model, multichannel_fn):
        super().__init__()
        self.imported_model = imported_model
        self.multichannel_fn = multichannel_fn

    def forward(self, x):
        """Stack the views, resize to 224x224, normalize per RGB triple, run the backbone."""
        stacked = self.multichannel_fn(x)
        upsampled = F.interpolate(stacked, size=(224, 224), mode="bicubic")
        # the ImageNet mean/std triple is repeated once per 3-channel view in the stack
        n_views = upsampled.shape[1] // 3
        normalized = torchvision.transforms.functional.normalize(
            upsampled,
            mean=[0.485, 0.456, 0.406] * n_views,
            std=[0.229, 0.224, 0.225] * n_views,
        )
        return self.imported_model(normalized)


# Wrap the backbone so the multi-resolution preprocessing happens inside forward(), then move it to the GPU.
wrapped_model = ImportedModelWrapper(imported_model, make_multichannel_input).to("cuda")
# Re-assigns the same function that was already passed to the constructor above.
wrapped_model.multichannel_fn = make_multichannel_input


# to get light adversarial training going, off by default
# to get light adversarial training going, off by default
def fgsm_attack(model, xs, ys, epsilon, random_reps=1, batch_size=64):
    """Perturb `xs` with a single signed-gradient (FGSM) step of size `epsilon`.

    Args:
        model: `nn.Module` mapping an NCHW float tensor to logits. It is put
            into eval mode and left there; callers restore train mode if needed.
        xs: NumPy array of images indexed (sample, H, W, C), values in [0, 1].
        ys: NumPy array of integer class labels, one per image.
        epsilon: Step size applied to the sign of the input gradient.
        random_reps: Number of forward/backward passes whose input gradients
            are accumulated before taking the sign (useful when the model's
            forward pass is stochastic).
        batch_size: Number of images attacked per pass.

    Returns:
        float32 NumPy array with the same (sample, H, W, C) layout as `xs`,
        clipped to [0, 1].
    """
    model = model.eval()

    # Run on the device the model lives on; fall back to the previous
    # hard-coded "cuda" for a model without parameters.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cuda")

    criterion = nn.CrossEntropyLoss()  # built once instead of once per repetition
    all_perturbed_images = []

    for i1 in range(0, xs.shape[0], batch_size):
        i2 = min(i1 + batch_size, xs.shape[0])

        x = torch.Tensor(xs[i1:i2].transpose([0, 3, 1, 2])).to(device)
        y = torch.Tensor(ys[i1:i2]).to(device).to(torch.long)

        x.requires_grad = True

        # backward() accumulates into x.grad, so repetitions sum their gradients
        for _ in range(random_reps):
            loss = criterion(model(x), y)
            loss.backward()

        perturbed_image = torch.clip(x + epsilon * x.grad.sign(), 0, 1)
        all_perturbed_images.append(perturbed_image.detach().cpu().numpy().transpose([0, 2, 3, 1]))

    if not all_perturbed_images:
        # empty input: return an empty batch instead of failing in np.concatenate
        return np.zeros((0,) + tuple(xs.shape[1:]), dtype=np.float32)

    return np.concatenate(all_perturbed_images, axis=0)


def train_model(
    model_in,
    images_in,
    labels_in,
    epochs=10,
    lr=1e-3,
    batch_size=512,
    optimizer_in=optim.Adam,
    subset_only=None,
    mode="eval",
    use_adversarial_training=False,
    adversarial_epsilon=8 / 255,
    skip_test_set_eval=False,
):
    """Train `model_in` with cross-entropy on in-memory NumPy data.

    Per-batch train accuracies are appended to the module-level list
    `train_accs` and per-epoch test accuracies to `test_accs`; both lists must
    exist before this function is called. The test set is read from the
    module-level `images_test_np` / `labels_test_np`.

    Args:
        model_in: Model to optimise (in place).
        images_in: NumPy array of images indexed (sample, H, W, C).
        labels_in: NumPy array of integer labels, one per image.
        epochs: Number of passes over the data.
        lr: Learning rate passed to the optimizer.
        batch_size: Samples per optimisation step.
        optimizer_in: Optimizer class, called as `optimizer_in(params, lr=lr)`.
        subset_only: Parameters to optimise; `None` means all of
            `model_in.parameters()`.
        mode: "train" or "eval"; the module mode the model is put in while
            it is being optimised.
        use_adversarial_training: If True, each batch is replaced by its FGSM
            perturbation before the optimisation step.
        adversarial_epsilon: FGSM step size used when adversarial training is on.
        skip_test_set_eval: If True, no test evaluation is run and the test
            accuracy is reported as 0.

    Returns:
        `model_in`, after training.

    Raises:
        AssertionError: If `mode` is neither "train" nor "eval".
    """

    def apply_mode():
        # Called up front and again every epoch / after every attack, because
        # evaluation and the attack may leave the model in a different mode.
        if mode == "train":
            model_in.train()
        elif mode == "eval":
            model_in.eval()
        else:
            raise AssertionError(f"unknown mode: {mode!r}")

    apply_mode()

    criterion = nn.CrossEntropyLoss()

    params = model_in.parameters() if subset_only is None else subset_only
    train_optimizer = optimizer_in(params, lr=lr)

    n_samples = len(images_in)
    its = int(np.ceil(float(n_samples) / float(batch_size)))

    for epoch in range(epochs):
        randomized_ids = np.random.permutation(range(n_samples))

        apply_mode()

        pbar = tqdm(range(its), desc="Training", ncols=100)

        all_hits = []

        for it in pbar:
            i1 = it * batch_size
            i2 = min(i1 + batch_size, n_samples)

            ids_now = randomized_ids[i1:i2]

            np_images_used = images_in[ids_now]
            np_labels_used = labels_in[ids_now]

            # very light adversarial training if on
            if use_adversarial_training:
                np_images_used = fgsm_attack(
                    model_in.eval(),
                    np_images_used,
                    np_labels_used,
                    epsilon=adversarial_epsilon,
                    random_reps=1,
                    batch_size=max(1, batch_size // 2),  # never 0, even for batch_size=1
                )
                apply_mode()

            inputs = torch.Tensor(np_images_used.transpose([0, 3, 1, 2])).to("cuda")
            labels = torch.Tensor(np_labels_used).to("cuda").to(torch.long)

            # zero the parameter gradients (also clears what the attack left behind)
            train_optimizer.zero_grad()

            # the actual optimization step
            outputs = model_in(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            train_optimizer.step()

            # on the fly eval for the train set batches
            hits = (torch.argmax(outputs, axis=-1) == labels).to(torch.float)
            acc = torch.mean(hits, axis=-1).detach().cpu().numpy()
            all_hits.append(hits.detach().cpu().numpy())
            train_accs.append(acc)

            pbar.set_description(f"train acc={acc} loss={loss.item()}")

        if not skip_test_set_eval:
            # evaluate a copy inside an isolated RNG context so neither the
            # training model's mode nor the random stream is disturbed
            with isolated_environment():
                eval_model_copy = copy.deepcopy(model_in)
                test_hits, test_count, _ = eval_model(eval_model_copy.eval(), images_test_np, labels_test_np)
        else:
            # to avoid dividing by zero
            test_hits = 0
            test_count = 1

        # end of epoch eval
        epoch_hits = np.concatenate(all_hits, axis=0).reshape([-1])
        train_hits = np.sum(epoch_hits)
        train_count = epoch_hits.shape[0]
        print(f"e={epoch} train {train_hits} / {train_count} = {train_hits/train_count},  test {test_hits} / {test_count} = {test_hits/test_count}")

        test_accs.append(test_hits / test_count)

    print("\nFinished Training")

    return model_in


# Hyper-parameters for the full-network finetuning run.
lr = 3.3e-5  # found with very simple "grid search" by hand, likely not optimal!
mode = "train"

epochs = 6

# Train a copy so wrapped_model keeps its freshly initialised state.
model = copy.deepcopy(wrapped_model)
model.multichannel_fn = make_multichannel_input

if mode == "eval":
    model = model.eval()
elif mode == "train":
    model = model.train()
else:
    assert False

# Accuracy logs; train_model appends to these module-level lists.
train_accs = []
test_accs = []

torch.cuda.empty_cache()

# NOTE(review): `device` is not used anywhere in the visible code.
device = torch.device("cuda:0")

# with torch.autocast("cuda"):
model = train_model(
    model,
    images_train_np,
    labels_train_np,
    epochs=epochs,
    lr=lr,
    optimizer_in=optim.Adam,
    batch_size=128,
    mode=mode,
)

In [None]:
#
# model
#

# Backbone: torchvision ResNet152 initialised from the IMAGENET1K_V1 weights.
# NOTE(review): the earlier cell loads IMAGENET1K_V2 -- confirm which one is intended.
imported_model = resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)
imported_model.fc = nn.Linear(2048, classes)  # set num of classes

# update first conv layer for multi-res
in_planes = 3
planes = 64
stride = 2
N = len(resolutions)  # input channels multiplier due to multi-res input
# Freshly initialised conv taking N * 3 input channels; replaces the pretrained conv1.
conv2 = nn.Conv2d(N * in_planes, planes, kernel_size=7, stride=stride, padding=3, bias=False)
imported_model.conv1 = copy.deepcopy(conv2)


class ImportedModelWrapper(nn.Module):
    """Backbone plus the multi-resolution preprocessing applied in `forward`.

    Attributes set in `__init__`:
        imported_model: The backbone network fed with the preprocessed stack.
        multichannel_fn: Callable that expands a batch of images into the
            channel-stacked multi-resolution input.
    """

    def __init__(self, imported_model, multichannel_fn):
        super().__init__()
        self.imported_model = imported_model
        self.multichannel_fn = multichannel_fn

    def forward(self, x):
        """Expand to the multi-res stack, resize, normalize and run the backbone."""
        # custom multi-resolution preprocessing
        multires = self.multichannel_fn(x)
        # standard ResNet preprocessing: 224x224 input, ImageNet statistics
        # repeated once per 3-channel view
        resized = F.interpolate(multires, size=(224, 224), mode="bicubic")
        repeats = resized.shape[1] // 3
        normalize = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406] * repeats,
            std=[0.229, 0.224, 0.225] * repeats,
        )
        return self.imported_model(normalize(resized))


# Wrap the backbone so the multi-resolution preprocessing happens inside forward(), then move it to the GPU.
wrapped_model = ImportedModelWrapper(imported_model, make_multichannel_input).to("cuda")
# Re-assigns the same function that was already passed to the constructor above.
wrapped_model.multichannel_fn = make_multichannel_input
# Train a copy so wrapped_model keeps its freshly initialised state.
model = copy.deepcopy(wrapped_model)
model.multichannel_fn = make_multichannel_input
model = model.train()

#
# train
#


def eval_model(model, images_in, labels_in, batch_size=128):
    """Run `model` over `images_in` in batches and count correct predictions.

    Args:
        model: Callable mapping an NCHW float tensor on "cuda" to logits. It
            is used as-is; the caller decides train/eval mode.
        images_in: NumPy array of images indexed (sample, H, W, C).
        labels_in: NumPy array of integer labels, one per image.
        batch_size: Number of images per forward pass.

    Returns:
        Tuple `(n_correct, n_evaluated, logits)` where `logits` is the NumPy
        array of all model outputs concatenated along axis 0.
    """
    all_preds = []
    all_logits = []

    with torch.no_grad():
        its = int(np.ceil(float(len(images_in)) / float(batch_size)))
        pbar = tqdm(range(its), desc="Eval", ncols=100)
        for it in pbar:
            i1 = it * batch_size
            i2 = min([(it + 1) * batch_size, len(images_in)])

            # The forward pass has to run inside the loop: when it sat after
            # the loop only the final batch was evaluated, and its predictions
            # were then compared against the labels of the whole set.
            inputs = torch.Tensor(images_in[i1:i2].transpose([0, 3, 1, 2])).to("cuda")
            outputs = model(inputs)
            all_logits.append(outputs.detach().cpu().numpy())
            preds = torch.argmax(outputs, axis=-1)
            all_preds.append(preds.detach().cpu().numpy())

    all_preds = np.concatenate(all_preds, axis=0)
    all_logits = np.concatenate(all_logits, axis=0)
    return np.sum(all_preds == labels_in), all_preds.shape[0], all_logits


def fgsm_attack(model, xs, ys, epsilon, random_reps=1, batch_size=64):  # optional light adv training (false in paper)
    """Perturb `xs` with a single signed-gradient (FGSM) step of size `epsilon`.

    Args:
        model: `nn.Module` mapping an NCHW float tensor to logits. It is put
            into eval mode and left there; callers restore train mode if needed.
        xs: NumPy array of images indexed (sample, H, W, C), values in [0, 1].
        ys: NumPy array of integer class labels, one per image.
        epsilon: Step size applied to the sign of the input gradient.
        random_reps: Number of forward/backward passes whose input gradients
            are accumulated before taking the sign.
        batch_size: Number of images attacked per pass.

    Returns:
        float32 NumPy array with one perturbed image per input image, in the
        same (sample, H, W, C) layout, clipped to [0, 1].
    """
    model = model.eval()

    # Run on the device the model lives on; fall back to the previous
    # hard-coded "cuda" for a model without parameters.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cuda")

    criterion = nn.CrossEntropyLoss()
    all_perturbed_images = []

    for i1 in range(0, xs.shape[0], batch_size):
        i2 = min(i1 + batch_size, xs.shape[0])

        # Everything below must run once per batch. When it sat after the loop
        # only the last batch was attacked, so the result had fewer images
        # than `ys` has labels.
        x = torch.Tensor(xs[i1:i2].transpose([0, 3, 1, 2])).to(device)
        y = torch.Tensor(ys[i1:i2]).to(device).to(torch.long)

        x.requires_grad = True
        # backward() accumulates into x.grad, so repetitions sum their gradients
        for _ in range(random_reps):
            loss = criterion(model(x), y)
            loss.backward()

        perturbed_image = torch.clip(x + epsilon * x.grad.sign(), 0, 1)
        all_perturbed_images.append(perturbed_image.detach().cpu().numpy().transpose([0, 2, 3, 1]))

    if not all_perturbed_images:
        # empty input: return an empty batch instead of failing in np.concatenate
        return np.zeros((0,) + tuple(xs.shape[1:]), dtype=np.float32)

    return np.concatenate(all_perturbed_images, axis=0)


def train_model(
    model_in,
    images_in,
    labels_in,
    epochs=10,
    lr=1e-3,
    batch_size=512,
    optimizer_in=optim.Adam,
    subset_only=None,
    mode="eval",
    use_adversarial_training=False,
    adversarial_epsilon=8 / 255,
    skip_test_set_eval=False,
):
    """Train `model_in` with cross-entropy on in-memory NumPy data.

    Per-batch train accuracies are appended to the module-level list
    `train_accs` and per-epoch test accuracies to `test_accs`; both lists must
    exist before this function is called. The test set is read from the
    module-level `images_test_np` / `labels_test_np`.

    Args:
        model_in: Model to optimise (in place).
        images_in: NumPy array of images indexed (sample, H, W, C).
        labels_in: NumPy array of integer labels, one per image.
        epochs: Number of passes over the data.
        lr: Learning rate passed to the optimizer.
        batch_size: Samples per optimisation step.
        optimizer_in: Optimizer class, called as `optimizer_in(params, lr=lr)`.
        subset_only: Parameters to optimise; `None` means all of
            `model_in.parameters()`.
        mode: "train" or "eval"; the module mode the model is put in while
            it is being optimised.
        use_adversarial_training: If True, each batch is replaced by its FGSM
            perturbation before the optimisation step.
        adversarial_epsilon: FGSM step size used when adversarial training is on.
        skip_test_set_eval: If True, no test evaluation is run and the test
            accuracy is reported and logged as 0.

    Returns:
        `model_in`, after training.

    Raises:
        AssertionError: If `mode` is neither "train" nor "eval".
    """

    def apply_mode():
        # Called up front and again every epoch / after every attack, because
        # evaluation and the attack may leave the model in a different mode.
        if mode == "train":
            model_in.train()
        elif mode == "eval":
            model_in.eval()
        else:
            raise AssertionError(f"unknown mode: {mode!r}")

    apply_mode()

    criterion = nn.CrossEntropyLoss()

    params = model_in.parameters() if subset_only is None else subset_only
    train_optimizer = optimizer_in(params, lr=lr)

    n_samples = len(images_in)
    its = int(np.ceil(float(n_samples) / float(batch_size)))

    for epoch in range(epochs):
        randomized_ids = np.random.permutation(range(n_samples))
        apply_mode()
        pbar = tqdm(range(its), desc="Training", ncols=100)

        all_hits = []
        for it in pbar:
            i1 = it * batch_size
            i2 = min(i1 + batch_size, n_samples)
            ids_now = randomized_ids[i1:i2]
            np_images_used = images_in[ids_now]
            np_labels_used = labels_in[ids_now]

            # light adversarial training
            if use_adversarial_training:
                np_images_used = fgsm_attack(
                    model_in.eval(),
                    np_images_used,
                    np_labels_used,
                    epsilon=adversarial_epsilon,
                    random_reps=1,
                    batch_size=max(1, batch_size // 2),  # never 0, even for batch_size=1
                )
                apply_mode()

            # forward and optimize
            inputs = torch.Tensor(np_images_used.transpose([0, 3, 1, 2])).to("cuda")
            labels = torch.Tensor(np_labels_used).to("cuda").to(torch.long)
            train_optimizer.zero_grad()  # also clears gradients left behind by the attack
            outputs = model_in(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            train_optimizer.step()

            # eval on trainset
            hits = (torch.argmax(outputs, axis=-1) == labels).to(torch.float)
            acc = torch.mean(hits, axis=-1).detach().cpu().numpy()
            all_hits.append(hits.detach().cpu().numpy())
            train_accs.append(acc)
            pbar.set_description(f"train acc={acc} loss={loss.item()}")

        # eval on testset
        if not skip_test_set_eval:
            # evaluate a copy inside an isolated RNG context so neither the
            # training model's mode nor the random stream is disturbed
            with isolated_environment():
                eval_model_copy = copy.deepcopy(model_in)
                test_hits, test_count, _ = eval_model(eval_model_copy.eval(), images_test_np, labels_test_np)
        else:
            # Placeholders so the summary line below does not hit undefined
            # names; count is 1 to avoid dividing by zero.
            test_hits = 0
            test_count = 1

        epoch_hits = np.concatenate(all_hits, axis=0).reshape([-1])
        train_hits = np.sum(epoch_hits)
        train_count = epoch_hits.shape[0]
        print(f"e={epoch} train {train_hits} / {train_count} = {train_hits/train_count},  test {test_hits} / {test_count} = {test_hits/test_count}")
        test_accs.append((test_hits / test_count) if (not skip_test_set_eval and test_count > 0) else 0)
    print("done")
    return model_in


# Accuracy logs; train_model appends to these module-level lists.
train_accs = []
test_accs = []

torch.cuda.empty_cache()
# NOTE(review): the whole train_model call -- forward passes, loss.backward(),
# optimizer steps and the per-epoch test evaluation -- runs inside this autocast
# context, and no gradient scaler is used in the visible code. PyTorch's AMP
# guidance is to wrap only the forward pass and loss computation; confirm this
# is intended.
with torch.autocast("cuda"):
    model = train_model(
        model_in=model,
        images_in=images_train_np,
        labels_in=labels_train_np,
        epochs=6,  # takes 2h on A100
        lr=3.3e-5,  # likely not optimal
        batch_size=128,
        optimizer_in=optim.Adam,
        subset_only=None,
        mode="train",
        use_adversarial_training=False,
        adversarial_epsilon=8 / 255,
        skip_test_set_eval=False,
    )

# device = torch.device("cuda:0")