Trying to change as few lines as possible while still getting this notebook to run on the cluster.

In [None]:
!pip install torch torchvision numpy tqdm matplotlib

In [None]:
import copy
import gc
import os
import random
from contextlib import contextmanager
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from tqdm import tqdm


def free_mem(): # keep like prev version
    """Release as much cached GPU memory as possible.

    Runs Python's garbage collector first, so that tensors only kept alive by
    unreachable Python objects are freed. It then asks PyTorch's CUDA caching
    allocator to release its unoccupied cached blocks, waits for queued CUDA work
    to finish, and collects memory released through CUDA IPC.

    NOTE(review): none of the CUDA calls are guarded by `torch.cuda.is_available()`,
    so this assumes a CUDA device is present.
    """
    gc.collect()  # free unreachable Python objects (and the tensors they hold) first
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.ipc_collect()


# PYTORCH_CUDA_ALLOC_CONF is read when PyTorch initialises CUDA and its caching allocator.
# free_mem() calls torch.cuda.synchronize(), which performs that initialisation. The variable
# therefore has to be set *before* free_mem(); otherwise the settings below are silently ignored.
# NOTE(review): the `backend:` key is, as far as I know, parsed even earlier (when torch loads
# its CUDA library), so it only takes effect if set before `import torch`. `native` is the
# default backend anyway. Confirm if the backend ever needs to change.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:native,max_split_size_mb:512,garbage_collection_threshold:0.8,expandable_segments:True" # keep like prev version

free_mem()

# Allow TF32 tensor-core math for matmuls and cuDNN convolutions (faster, slightly less precise).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Let cuDNN benchmark and cache the fastest convolution algorithms for the shapes it sees.
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")


# 
# context manager
# 


@contextmanager
def isolated_environment():
    """Run a block of code without leaking RNG-state or NumPy print-option changes.

    On entry this snapshots:
    - NumPy's global RNG state;
    - Python's ``random`` state;
    - torch's CPU RNG state;
    - the per-device torch CUDA RNG states, when CUDA is available;
    - NumPy's print options.

    On exit (normally or through an exception) every snapshot is restored. Code inside
    the ``with`` block may therefore reseed or consume random numbers freely.
    """
    np_random_state = np.random.get_state()
    python_random_state = random.getstate()
    torch_random_state = torch.get_rng_state()
    # One state per visible CUDA device; None is the sentinel for "CUDA unavailable".
    cuda_random_state = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
    numpy_print_options = np.get_printoptions()
    try:
        yield
    finally:
        np.random.set_state(np_random_state)
        random.setstate(python_random_state)
        torch.set_rng_state(torch_random_state)
        # Compare against the sentinel explicitly instead of relying on list truthiness.
        if cuda_random_state is not None:
            torch.cuda.set_rng_state_all(cuda_random_state)
        np.set_printoptions(**numpy_print_options)


#
# data
#


# All artefact folders live next to the notebook's working directory, i.e. under its parent.
classes_path, dataset_path, weights_path = (
    Path.cwd().parent / folder_name for folder_name in ("data", "datasets", "weights")
)

# Create them up front (including missing parents); existing directories are left as they are.
# NOTE(review): classes_path is not read anywhere else in the visible code.
for _directory in (classes_path, dataset_path, weights_path):
    _directory.mkdir(parents=True, exist_ok=True)
del _directory


# Number of target classes; also selects which dataset is loaded below.
classes = 100
assert classes == 100  # the notebook is currently only set up for CIFAR-100

if classes in (10, 100):
    # CIFAR-10 and CIFAR-100 take the same constructor arguments and expose images as
    # `.data` and labels as `.targets`, so a single code path loads either of them.
    cifar_dataset = torchvision.datasets.CIFAR10 if classes == 10 else torchvision.datasets.CIFAR100
    trainset = cifar_dataset(root=dataset_path, train=True, download=True)
    testset = cifar_dataset(root=dataset_path, train=False, download=True)

    original_images_train_np = np.array(trainset.data)
    original_labels_train_np = np.array(trainset.targets)

    original_images_test_np = np.array(testset.data)
    original_labels_test_np = np.array(testset.targets)

else:
    # MNIST fallback (unreachable while the assert above is in place). It is stored under
    # `dataset_path` like the CIFAR datasets, so every download shares one location.
    trainset = torchvision.datasets.MNIST(root=dataset_path, train=True, download=True)
    testset = torchvision.datasets.MNIST(root=dataset_path, train=False, download=True)

    # Stack three copies of the grey image along a new axis 3 to obtain three identical
    # colour channels, like the CIFAR images.
    original_images_train_np = np.stack([np.array(trainset.data)] * 3, axis=3)
    original_labels_train_np = np.array(trainset.targets)

    original_images_test_np = np.stack([np.array(testset.data)] * 3, axis=3)
    original_labels_test_np = np.array(testset.targets)

    classes = 10

# Rescale pixel intensities from [0, 255] to [0, 1]; the labels are used unchanged.
images_train_np, images_test_np = (
    raw_images / 255.0 for raw_images in (original_images_train_np, original_images_test_np)
)

labels_train_np, labels_test_np = original_labels_train_np, original_labels_test_np


#
# preprocessing
#


# to be able to replace the random number generator by other things if needed
def custom_rand(input_tensor, size):
    """Return uniform [0, 1) noise of shape ``size`` as a float32 tensor.

    The values come from NumPy's global RNG (``np.random.rand``), so they follow
    ``np.random.seed``. ``input_tensor`` does not influence the values. It is part of the
    signature so this hook can be swapped for an input-dependent generator.

    Output device:
    - when ``input_tensor`` is a torch tensor, the output is placed on its device;
    - otherwise the output goes to ``"cuda"``, as before.

    Callers in this notebook pass CUDA tensors, so their behaviour is unchanged.
    """
    device = input_tensor.device if isinstance(input_tensor, torch.Tensor) else "cuda"
    return torch.Tensor(np.random.rand(*size)).to(device)


def custom_choices(items, tensor):
    """Draw ``len(tensor)`` samples, with replacement, from ``items``.

    Uses NumPy's global RNG. Only the length of ``tensor`` matters; the argument exists
    so this hook can later be replaced by an input-dependent choice function.
    """
    sample_count = len(tensor)
    return np.random.choice(items, size=sample_count)


# apply image augmentations to input images
def apply_transformations(
    images,
    down_res=224,
    up_res=224,
    jit_x=0,
    jit_y=0,
    down_noise=0.0,
    up_noise=0.0,
    contrast=1.0,
    color_amount=1.0,
):
    """Augment a batch of images and pass it through a down/up-resampling round trip.

    Per image i, in order:
    - adjust the contrast by ``contrast[i]``;
    - roll by (``jit_x[i]``, ``jit_y[i]``) pixels along the last two axes;
    - blend with the image's channel mean (``color_amount`` 1 keeps colour, 0 gives grey).

    Then, for the whole batch:
    - bicubic resize to ``down_res`` x ``down_res`` and add ``down_noise`` times noise
      from ``custom_rand``;
    - bicubic resize to ``up_res`` x ``up_res`` and add ``up_noise`` times noise;
    - clip to [0, 1].

    Args:
        images: batch tensor, presumably (N, C, H, W) with values in [0, 1]. TODO: confirm.
        down_res, up_res: side lengths of the intermediate and final square resolutions.
        jit_x, jit_y, contrast, color_amount: one value per image (indexable, length N),
            or a single scalar that is applied to every image.
        down_noise, up_noise: multipliers for the noise added at each resolution.

    Returns:
        Tensor of size up_res x up_res in the last two dims, clipped to [0, 1].
    """
    batch_size = images.shape[0]

    # The loop below indexes these per image. Broadcast scalars (including the defaults,
    # which used to raise "not subscriptable") to one value per image.
    jit_x, jit_y, contrast, color_amount = (
        [value] * batch_size if np.isscalar(value) else value
        for value in (jit_x, jit_y, contrast, color_amount)
    )

    # # for MNIST alone
    # images = torch.mean(images,axis=1,keepdims=True)

    images_collected = []

    for i in range(batch_size):
        image = images[i]

        # changing contrast
        image = torchvision.transforms.functional.adjust_contrast(image, contrast[i])

        # shift the result in x and y
        image = torch.roll(image, shifts=(jit_x[i], jit_y[i]), dims=(-2, -1))

        # shifting in the color <-> grayscale axis
        image = color_amount[i] * image + torch.mean(image, axis=0, keepdims=True) * (1 - color_amount[i])

        images_collected.append(image)

    images = torch.stack(images_collected, axis=0)

    # decrease the resolution, then add low-res noise.
    # The noise always has 3 channels; a 1-channel batch broadcasts up to 3 when it is added.
    images = F.interpolate(images, size=(down_res, down_res), mode="bicubic")
    noise = down_noise * custom_rand(images + 312, (batch_size, 3, down_res, down_res)).to(images.device)
    images = images + noise

    # increase the resolution, then add high-res noise
    images = F.interpolate(images, size=(up_res, up_res), mode="bicubic")
    noise = up_noise * custom_rand(images + 812, (batch_size, 3, up_res, up_res)).to(images.device)
    images = images + noise

    # clipping to the right range of values
    return torch.clip(images, 0, 1)


resolutions = [32, 16, 8, 4]  # pretty arbitrary; one augmented 3-channel copy of the input is stacked per resolution
down_noise = 0.2  # scale of the noise added at the low resolution: custom_rand draws uniform [0, 1), so the noise lies in [0, 0.2) and is not zero-mean
up_noise = 0.2  # scale of the noise added at the high resolution: likewise uniform in [0, up_noise), not a standard deviation
jit_size = 3  # max size of the x-y jit in each axis, sampled uniformly from -jit_size to +jit_size inclusive

# to shuffle randomly which image is which in the multi-res stack
# False for all experiments in the paper, good for ablations
shuffle_image_versions_randomly = False


def default_make_multichannel_input(images):
    """Baseline stack: ``images`` repeated unchanged once per entry of ``resolutions``.

    The copies are concatenated along dim 1, with no augmentation and no resampling.
    """
    copies = [images for _ in resolutions]
    return torch.cat(copies, dim=1)


def make_multichannel_input(
    images,
    contrast=1.0,
    up_res=32,  # hard coded for CIFAR-10 or CIFAR-100
):
    """Build the multi-resolution input stack for a batch of images.

    For every entry ``r`` of ``resolutions``, the batch goes through
    ``apply_transformations`` with per-image random parameters:
    - x/y jitter in [-jit_size, jit_size];
    - a contrast factor from linspace(0.7, 1.5);
    - a colour amount from linspace(0.5, 1.0);
    - down-sampling to ``r`` and back up to ``up_res``, with noise.

    The resulting versions are concatenated along dim 1, in ``resolutions`` order unless
    ``shuffle_image_versions_randomly`` is set.

    ``contrast`` is accepted for interface compatibility but is unused; the per-image
    contrast is sampled randomly.
    """
    all_channels = []

    for i, r in enumerate(resolutions):
        # custom_choices currently only uses the batch size of its second argument. The
        # distinct expressions keep the hook ready for an input-dependent generator.
        jits_x = custom_choices(range(-jit_size, jit_size + 1), images + i)  # x-shift
        jits_y = custom_choices(range(-jit_size, jit_size + 1), 51 * images + 7 * i + 125 * r)  # y-shift
        contrasts = custom_choices(np.linspace(0.7, 1.5, 100), 7 + 3 * images + 9 * i + 5 * r)  # change in contrast
        # Assigned on its own. This previously read `color_amounts = contrasts = ...`, which
        # replaced the contrast factors above with the colour amounts.
        # NOTE(review): results produced before this fix effectively used contrast == colour amount.
        color_amounts = custom_choices(np.linspace(0.5, 1.0, 100), 5 + 7 * images + 8 * i + 2 * r)  # change in color amount

        images_now = apply_transformations(
            images,
            down_res=r,
            up_res=up_res,
            jit_x=jits_x,
            jit_y=jits_y,
            down_noise=down_noise,
            up_noise=up_noise,
            contrast=contrasts,
            color_amount=color_amounts,
        )

        all_channels.append(images_now)

    if shuffle_image_versions_randomly:
        order = torch.randperm(len(all_channels))
        all_channels = [all_channels[j] for j in order]

    return torch.concatenate(all_channels, axis=1)


# sample_images = images_test_np[:5]

# for j in [0, 1]:
#     multichannel_images = (
#         make_multichannel_input(
#             torch.Tensor(sample_images.transpose([0, 3, 1, 2])).to("cuda"),
#         )
#         .detach()
#         .cpu()
#         .numpy()
#         .transpose([0, 2, 3, 1])
#     )

#     N = 1 + multichannel_images.shape[3] // 3

#     plt.figure(figsize=(N * 5.5, 5))

#     plt.subplot(1, N, 1)
#     plt.title("original")
#     plt.imshow(sample_images[j])
#     plt.xticks([], [])
#     plt.yticks([], [])

#     for i in range(N - 1):
#         plt.subplot(1, N, i + 2)
#         plt.title(f"res={resolutions[i]}")
#         plt.imshow(multichannel_images[j, :, :, 3 * i : 3 * (i + 1)])
#         plt.xticks([], [])
#         plt.yticks([], [])

#     plt.show()

In [None]:
class BatchNormLinear(nn.Module):
    """Linear classifier head that batch-normalises its input features first.

    ``forward(x)`` returns ``linear(batch_norm(x))``. In this notebook it is called on
    flattened (batch, in_features) activations.
    """

    def __init__(self, in_features, out_features, device="cuda"):
        super().__init__()
        # These attribute names become state_dict keys, so they must stay stable.
        self.batch_norm = nn.BatchNorm1d(in_features, device=device)
        self.linear = nn.Linear(in_features, out_features, device=device)

    def forward(self, x):
        return self.linear(self.batch_norm(x))


class WrapModelForResNet152(torch.nn.Module):
    """Expose every stage of a ResNet-152-style backbone to per-layer linear probes.

    Input preparation: ``multichannel_fn`` turns the batch into a multi-channel stack,
    which is resized to 224x224 (bicubic). It is then normalised with mean
    (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225), repeated once per 3-channel group.

    ``layer_operations`` lists the stem, every block of layer1..layer4, the average pool
    and the final fc. ``linear_layers[i]`` is a ``BatchNormLinear`` probe. It maps the
    flattened activation after the first ``i`` operations (``all_dims[i]`` features) to
    ``classes`` logits.

    NOTE(review): ``all_dims[0]`` assumes the stem (``model.conv1``) receives
    ``3 * len(resolutions)`` channels. A stock 3-channel ResNet stem would fail; confirm
    the wrapped model was adapted.
    NOTE(review): a probe is allocated for *every* position. With ResNet-152 block counts
    (3, 8, 36, 3) that is ~1.4e7 input features in total, i.e. ~1.4e9 weights for 100
    classes (~5.6 GB in fp32). This is a likely cause of out-of-memory failures.
    """

    def __init__(self, model, multichannel_fn, classes=10, device="cuda"):
        super().__init__()

        self.multichannel_fn = multichannel_fn

        self.model = model

        self.classes = classes

        # Plain list, not ModuleList: these modules are already reachable through `self.model`.
        self.layer_operations = [
            torch.nn.Sequential(
                model.conv1,
                model.bn1,
                model.relu,
                model.maxpool,
            ),
            *model.layer1,
            *model.layer2,
            *model.layer3,
            *model.layer4,
            model.avgpool,
            model.fc,
        ]

        # Flattened activation size after the first i operations, for a 224x224 input.
        self.all_dims = [
            3 * 224 * 224 * len(resolutions),
            64 * 56 * 56,
            *[256 * 56 * 56] * len(model.layer1),
            *[512 * 28 * 28] * len(model.layer2),
            *[1024 * 14 * 14] * len(model.layer3),
            *[2048 * 7 * 7] * len(model.layer4),
            2048,
            1000,
        ]

        # `device` replaces the previously hard-coded "cuda"; the default is unchanged.
        self.linear_layers = torch.nn.ModuleList([BatchNormLinear(dim, classes, device=device) for dim in self.all_dims])

    def prepare_input(self, x):
        """Apply ``multichannel_fn``, resize to 224x224 (bicubic) and normalise per channel."""
        x = self.multichannel_fn(x)
        x = F.interpolate(x, size=(224, 224), mode="bicubic")
        groups = x.shape[1] // 3
        normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406] * groups, std=[0.229, 0.224, 0.225] * groups)
        return normalize(x)

    def _apply_layer(self, x, l):
        """Run ``layer_operations[l]``; pooled (B, 2048, 1, 1) features are flattened first."""
        if list(x.shape)[1:] == [2048, 1, 1]:
            x = x.reshape([-1, 2048])
        return self.layer_operations[l](x)

    def forward_until(self, x, layer_id):
        """Return the activation after the first ``layer_id`` operations on raw input ``x``."""
        x = self.prepare_input(x)

        for l in range(layer_id):
            x = self._apply_layer(x, l)
        return x

    def forward_after(self, x, layer_id):
        """Run operations ``layer_id`` to the end on the *prepared raw input* ``x``.

        NOTE(review): ``x`` goes through ``prepare_input`` first. Shapes therefore
        presumably only line up for ``layer_id == 0``. If this is meant to resume from an
        intermediate activation (e.g. ``forward_until``'s output), the ``prepare_input``
        call looks wrong; confirm against callers.
        """
        x = self.prepare_input(x)

        for l in range(layer_id, len(self.layer_operations)):
            x = self._apply_layer(x, l)
        return x

    def predict_from_layer(self, x, l):
        """Logits of probe ``l``, applied to the flattened activation after ``l`` operations."""
        features = self.forward_until(x, l)
        return self.linear_layers[l](features.reshape([features.shape[0], -1]))

    def predict_from_several_layers(self, x, layers):
        """Run the backbone once and collect probe logits.

        Returns a dict with key 0 (probe on the prepared input) plus key ``l + 1`` for every
        operation index ``l`` in ``layers``.
        """
        x = self.prepare_input(x)

        outputs = {0: self.linear_layers[0](x.reshape([x.shape[0], -1]))}

        for l in range(len(self.layer_operations)):
            x = self._apply_layer(x, l)

            if l in layers:
                outputs[l + 1] = self.linear_layers[l + 1](x.reshape([x.shape[0], -1]))

        return outputs

# NOTE(review): `model` is not defined anywhere in the visible notebook and must already
# exist in the session. Its `imported_model` is wrapped directly, so that model's conv1
# must accept 3 * len(resolutions) input channels (cf. WrapModelForResNet152.all_dims[0]).
resnet152_wrapper = WrapModelForResNet152(model.imported_model, make_multichannel_input, classes=classes) # <--------- this usually breaks
resnet152_wrapper.multichannel_fn = make_multichannel_input
resnet152_wrapper = resnet152_wrapper.to("cuda")

# Shape sanity check: probe every layer position with a dummy all-zero batch.
# It runs under eval() + no_grad() for two reasons:
# - in train mode, each of these forward passes would update the BatchNorm running
#   statistics (backbone and probes) with zero-input statistics;
# - autograd would keep every activation alive for nothing.
# Each submodule's previous train/eval flag is restored afterwards.
training_modes = [(module, module.training) for module in resnet152_wrapper.modules()]
resnet152_wrapper.eval()
dummy_batch = torch.zeros((2, 3, 32, 32), device="cuda")
with torch.no_grad():
    for layer_i in range(len(resnet152_wrapper.layer_operations)):
        print(f"layer={layer_i} {resnet152_wrapper.predict_from_layer(dummy_batch, layer_i).shape}")
for module, was_training in training_modes:
    module.training = was_training
del dummy_batch, training_modes

class LinearNet(nn.Module):
    """Module whose forward returns the logits of a single probe of the wrapped model.

    ``layer_i`` selects the probe (via ``model.predict_from_layer``) and may be
    reassigned between training runs.
    """

    def __init__(self, model, layer_i):
        super().__init__()
        self.model = model
        self.layer_i = layer_i

    def forward(self, inputs):
        probe_index = self.layer_i
        return self.model.predict_from_layer(inputs, probe_index)

# Rebind instead of deep-copying. copy.deepcopy duplicated the whole wrapper (backbone plus
# all probe heads) on the GPU right before the original was deleted, doubling peak memory.
# NOTE(review): unlike the deep copy, backbone_model now shares its backbone modules with
# `model.imported_model`. This only matters if `model` is used again later and must stay untouched.
backbone_model = resnet152_wrapper
del resnet152_wrapper

# only training some layers to save time -- super early ones are bad on anything harder than CIFAR-10
layers_to_use = [20, 30, 35, 40, 45, 50, 52]

lr = 3.3e-5  # random stuff again
epochs = 1
batch_size = 64  # for CUDA RAM reasons

mode = "train"
backbone_model.eval()

linear_model = LinearNet(backbone_model, 5).to("cuda")  # just to have it ready

torch.cuda.empty_cache()

# Not used in this cell; kept because train_model (defined elsewhere) may read it.
device = torch.device("cuda:0")

linear_layers_collected_dict = dict()

for layer_i in reversed(layers_to_use):
    print(f"///////// layer={layer_i} ///////////")

    linear_model.layer_i = layer_i
    linear_model.fixed_mode = "train"

    # Reset for every layer; presumably filled in by train_model (not visible here).
    train_accs = []
    test_accs = []
    robust_accs = []
    clean_accs = []
    actual_robust_accs = []

    all_models = []

    torch.cuda.empty_cache()

    linear_model = train_model(
        linear_model,
        images_train_np[:],
        labels_train_np[:],
        epochs=epochs,
        lr=lr,
        optimizer_in=optim.Adam,
        batch_size=batch_size,
        mode=mode,
        subset_only=linear_model.model.linear_layers[layer_i].parameters(),  # just the linear projection
        use_adversarial_training=False,
        adversarial_epsilon=None,
    )

    # Snapshot the head from the model train_model returned, i.e. the one that was just
    # trained. This stays correct even if train_model returns a copy instead of training
    # in place.
    linear_layers_collected_dict[layer_i] = copy.deepcopy(linear_model.model.linear_layers[layer_i])