# Helper Functions

## Evaluate Adversarial Robustness

In [None]:
import argparse
import torch.nn as nn
import utils
import random
import numpy as np
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import models
import os


# Architectures selectable through --net; values are constructors from the
# local `models` package.
NETWORKS = {
    arch: getattr(models, arch)
    for arch in ("resnet18", "resnet34", "resnet50", "resnet101", "resnet152")
}

# Checkpoint grid swept below: data fraction x augmentation x the iteration
# index used in the `best_imp_<n>.ckpt` file names.
DATA_SIZES = ["1", "0.5", "0.2", "0.1", "0.02", "0.01"]
AUGS = ["baseaug", "contrastaug", "randaug", "autoaug"]
ITERS = [0, 4, 9, 14]
# Parent directory of the checkpoint folders ("" means the working directory).
OUT_DIR = ""

parser = utils.add_args(argparse.ArgumentParser())
parser.add_argument("--eps", type=float, default=8 / 255, help="fgsm attack eps")
args = parser.parse_args()


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


_seed_everything(args.seed)

device, _ = utils.setup_device(False)
criterion = torch.nn.CrossEntropyLoss()
# Running loss/accuracy accumulator shared by every eval() call.
metric_meter = utils.AvgMeter()


class Normalize(nn.Module):
    """Per-channel ``(x - mean) / std`` normalisation packaged as a module.

    The evaluation loop prepends this to the network, so the attack code can
    perturb and clamp raw ``[0, 1]`` pixel tensors while the network still
    receives normalised input.

    Args:
        mean: per-channel means (sequence or tensor).
        std: per-channel standard deviations (sequence or tensor).
    """

    def __init__(self, mean, std):
        super().__init__()
        # Buffers follow .to(device) and appear in state_dict, but are not trained.
        self.register_buffer("mean", mean if isinstance(mean, torch.Tensor) else torch.tensor(mean))
        self.register_buffer("std", std if isinstance(std, torch.Tensor) else torch.tensor(std))

    def forward(self, inp):
        # Lift the (C,) statistics to (1, C, 1, 1) so they broadcast over NCHW input.
        shift = self.mean[None, :, None, None]
        scale = self.std[None, :, None, None]
        return (inp - shift) / scale


def eval(loader, model, metric_meter, attack=False):
    """Run one pass over `loader`, logging loss and accuracy into `metric_meter`.

    With `attack=True`, each batch is first replaced by a single-step FGSM
    example: the input is moved by ``args.eps`` times the sign of the input
    gradient of the loss, then clamped back to ``[0, 1]``. Metrics are always
    computed without gradients on whichever batch results.

    Args:
        loader: iterable of ``(images, targets)`` batches.
        model: classifier returning logits; switched to eval mode here.
        metric_meter: accumulator with reset/add/msg (``utils.AvgMeter``).
        attack: evaluate on FGSM-perturbed inputs instead of clean ones.
    """
    metric_meter.reset()
    model.eval()
    n_batches = len(loader)
    for step, (inputs, labels) in enumerate(loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if attack:
            inputs.requires_grad = True
            attack_loss = criterion(model(inputs), labels)
            (input_grad,) = torch.autograd.grad(
                attack_loss, inputs, retain_graph=False, create_graph=False
            )
            # Detach so the clean-graph history is not kept for the scoring pass.
            inputs = (inputs + args.eps * input_grad.sign()).clamp(min=0, max=1).detach()

        with torch.no_grad():
            logits = model(inputs)
            loss = criterion(logits, labels)

        predicted = logits.argmax(dim=1)
        batch_acc = predicted.eq(labels.view_as(predicted)).sum().item() / inputs.shape[0]

        metric_meter.add({"loss": loss.item(), "acc": batch_acc})
        utils.pbar(step / n_batches, msg=metric_meter.msg())
    utils.pbar(1, msg=metric_meter.msg())


# Per-dataset normalisation statistics, head size and matching clean test set.
# Resolved once up front so an unsupported --dset fails before any checkpoint
# is loaded, and so CIFAR-100 models are scored on CIFAR-100 (not CIFAR-10).
if args.dset == "cifar10":
    norm_stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    n_cls = 10
    test_dset_cls = datasets.CIFAR10
elif args.dset == "cifar100":
    norm_stats = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    n_cls = 100
    test_dset_cls = datasets.CIFAR100
else:
    raise NotImplementedError(f"args.dset = {args.dset} not implemented.")

# The test split does not depend on the checkpoint, so build it once. Only
# ToTensor is applied here: normalisation lives inside the model (Normalize),
# so the FGSM step and its [0, 1] clamp operate on raw pixel values.
test_dset = test_dset_cls(
    root=args.data_root,
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
loader = DataLoader(
    test_dset, batch_size=args.batch_size, shuffle=False, num_workers=args.n_workers
)

# One line per (data_size, aug); one FGSM accuracy (percent) per ITERS entry.
with open(f"{args.dset}_rob_a_results.txt", "w") as f:
    for data_size in DATA_SIZES:
        for aug in AUGS:
            for imp_iter in ITERS:
                ckpt_path = os.path.join(
                    OUT_DIR, f"sparse_advprop_{data_size}_{aug}", f"best_imp_{imp_iter}.ckpt"
                )
                print(f"Evaluating: {ckpt_path}")
                # map_location keeps this working when the checkpoint was saved on
                # a different device (e.g. saved on GPU, evaluated on CPU).
                ckpt = torch.load(ckpt_path, map_location=device)

                model = NETWORKS[args.net](n_cls=n_cls, pre_conv="small", pretrained=False).to(
                    device
                )
                utils.modify_bn(model)
                model.attacker = utils.PGDAttacker(
                    args.attack_n_iter, args.attack_eps, args.attack_step_size, 0.2
                )
                # Move again: modify_bn / the attacker may have added new members.
                model = model.to(device)
                if imp_iter:
                    # Pruned checkpoints: load the initial weights, then apply the
                    # mask extracted from the trained weights before loading them.
                    model.load_state_dict(ckpt["init"])
                    curr_mask = utils.extract_mask(ckpt["model"])
                    utils.mask_prune(model, curr_mask)
                    print("remaining weight = ", utils.check_sparsity(model))
                model.load_state_dict(ckpt["model"])
                model = nn.Sequential(Normalize(*norm_stats), model).to(device)

                eval(loader, model, metric_meter, attack=True)
                metrics = metric_meter.get()
                print(
                    f"{args.dset}: loss {round(metrics['loss'], 5)}, acc: {round(metrics['acc'], 5)}"
                )

                # Format explicitly: round(x, 4) * 100 can print float noise
                # such as 57.129999999999995.
                f.write(f"{metrics['acc'] * 100:.2f} ")
                print("finished evaluating on ckpt")
                print("---------------------------")
            f.write("\n")
            f.flush()

## Evaluate Robustness to Distribution Shifts

Download CIFAR10.2 from [https://github.com/modestyachts/cifar-10.2](https://github.com/modestyachts/cifar-10.2)

In [None]:
import argparse
import utils
import random
import numpy as np
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
import models
import os

# Architectures selectable through --net; values are constructors from the
# local `models` package.
NETWORKS = {
    arch: getattr(models, arch)
    for arch in ("resnet18", "resnet34", "resnet50", "resnet101", "resnet152")
}

# Checkpoint grid swept below: data fraction x augmentation x the iteration
# index used in the `best_imp_<n>.ckpt` file names.
DATA_SIZES = ["1", "0.5", "0.2", "0.1", "0.02", "0.01"]
AUGS = ["baseaug", "contrastaug", "randaug", "autoaug"]
ITERS = [0, 4, 9, 14]
# Parent directory of the checkpoint folders ("" means the working directory).
OUT_DIR = ""

parser = utils.add_args(argparse.ArgumentParser())
parser.add_argument(
    "--rob_data_root", type=str, required=True, help="path to transformed data directory"
)
args = parser.parse_args()


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


_seed_everything(args.seed)

device, _ = utils.setup_device(False)
criterion = torch.nn.CrossEntropyLoss()
# Running loss/accuracy accumulator shared by every eval() call.
metric_meter = utils.AvgMeter()


@torch.no_grad()
def eval(loader, model, metric_meter):
    """Score `model` on `loader` without gradients, logging into `metric_meter`.

    Resets the meter and puts the model in eval mode. Then, for each batch, it
    records the `criterion` loss and the fraction of correct top-1 predictions,
    updating a progress bar as it goes.
    """
    metric_meter.reset()
    model.eval()
    n_batches = len(loader)
    for step, (inputs, labels) in enumerate(loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        logits = model(inputs)
        batch_loss = criterion(logits, labels)

        predicted = logits.argmax(dim=1)
        n_correct = predicted.eq(labels.view_as(predicted)).sum().item()

        metric_meter.add({"loss": batch_loss.item(), "acc": n_correct / inputs.shape[0]})
        utils.pbar(step / n_batches, msg=metric_meter.msg())
    utils.pbar(1, msg=metric_meter.msg())


class CIFARRobustness(Dataset):
    """Train and test splits of a shifted CIFAR dataset, concatenated.

    Expects ``train.npz`` and ``test.npz`` under `root`, each holding an
    ``images`` array and a ``labels`` array. Items are ``(transform(image), label)``.

    Args:
        root: directory containing the two ``.npz`` archives.
        transform: callable applied to each image (e.g. ToTensor + Normalize).
    """

    def __init__(self, root, transform):
        imgs, labels = [], []
        for split in ("train", "test"):
            # Open each archive once and close it deterministically; indexing an
            # NpzFile reads the member into memory, so the arrays outlive the handle.
            with np.load(os.path.join(root, f"{split}.npz")) as archive:
                imgs.append(archive["images"])
                labels.append(archive["labels"])
        self.imgs = np.concatenate(imgs, axis=0)
        self.labels = np.concatenate(labels, axis=0)
        self.transform = transform

    def __getitem__(self, indx):
        img = self.imgs[indx]
        label = self.labels[indx]
        img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.imgs)


# Per-dataset normalisation, head size and matching clean test set, resolved
# once so an unsupported --dset fails before any checkpoint is loaded and so
# CIFAR-100 models are scored on CIFAR-100 (not CIFAR-10).
if args.dset == "cifar10":
    norm = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    n_cls = 10
    clean_dset_cls = datasets.CIFAR10
elif args.dset == "cifar100":
    norm = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    n_cls = 100
    clean_dset_cls = datasets.CIFAR100
else:
    raise NotImplementedError(f"args.dset = {args.dset} not implemented.")

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        norm,
    ]
)

# Neither dataset depends on the checkpoint, so build both loaders once.
clean_loader = DataLoader(
    clean_dset_cls(root=args.data_root, train=False, transform=transform, download=True),
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.n_workers,
)
# NOTE(review): the shifted data is expected to be CIFAR-10.2 (see the cell
# header), whose labels are CIFAR-10 classes; confirm before using cifar100.
shift_loader = DataLoader(
    CIFARRobustness(root=args.rob_data_root, transform=transform),
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.n_workers,
)

# One line per (data_size, aug); one shifted-data accuracy (percent) per ITERS entry.
with open(f"{args.dset}_rob_d_results.txt", "w") as f:
    for data_size in DATA_SIZES:
        for aug in AUGS:
            row = []
            for imp_iter in ITERS:
                ckpt_path = os.path.join(
                    OUT_DIR, f"sparse_{data_size}_{aug}", f"best_imp_{imp_iter}.ckpt"
                )
                print(f"Evaluating: {ckpt_path}")
                # map_location lets GPU-saved checkpoints load on the chosen device.
                ckpt = torch.load(ckpt_path, map_location=device)
                model = NETWORKS[args.net](n_cls=n_cls, pre_conv="small").to(device)
                if imp_iter:
                    # Pruned checkpoints: apply the mask extracted from the trained
                    # weights before loading them.
                    model.load_state_dict(ckpt["init"])
                    curr_mask = utils.extract_mask(ckpt["model"])
                    utils.mask_prune(model, curr_mask)
                    print("remaining weight = ", utils.check_sparsity(model))
                model.load_state_dict(ckpt["model"])

                # Clean accuracy is printed for reference only.
                eval(clean_loader, model, metric_meter)
                metrics = metric_meter.get()
                print(
                    f"{args.dset}: loss {round(metrics['loss'], 5)}, acc: {round(metrics['acc'], 5)}"
                )

                eval(shift_loader, model, metric_meter)
                metrics = metric_meter.get()
                print(
                    f"{args.dset}: loss {round(metrics['loss'], 5)}, acc: {round(metrics['acc'], 5)}"
                )
                # Explicit formatting avoids float noise from round(x, 4) * 100.
                row.append(f"{metrics['acc'] * 100:.2f}")

                print("finished evaluating on ckpt")
                print("---------------------------")

            f.write(" ".join(row) + "\n")
            f.flush()

## Evaluate Robustness to Synthetic Corruptions 

Download CIFAR10-C from [https://zenodo.org/record/2535967](https://zenodo.org/record/2535967)

In [None]:
import argparse
import utils
import random
import numpy as np
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
import models
import os

# Architectures selectable through --net; values are constructors from the
# local `models` package.
NETWORKS = {
    arch: getattr(models, arch)
    for arch in ("resnet18", "resnet34", "resnet50", "resnet101", "resnet152")
}

# Checkpoint grid swept below: data fraction x augmentation x the iteration
# index used in the `best_imp_<n>.ckpt` file names.
DATA_SIZES = ["1", "0.5", "0.2", "0.1", "0.02", "0.01"]
AUGS = ["baseaug", "contrastaug", "randaug", "autoaug"]
ITERS = [0, 4, 9, 14]
# Parent directory of the checkpoint folders ("" means the working directory).
OUT_DIR = ""

parser = utils.add_args(argparse.ArgumentParser())
parser.add_argument(
    "--rob_data_root", type=str, required=True, help="path to transformed data directory"
)
args = parser.parse_args()


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


_seed_everything(args.seed)

device, _ = utils.setup_device(False)
criterion = torch.nn.CrossEntropyLoss()
# Running loss/accuracy accumulator shared by every eval() call.
metric_meter = utils.AvgMeter()


@torch.no_grad()
def eval(loader, model, metric_meter):
    """Score `model` on `loader` without gradients, logging into `metric_meter`.

    Resets the meter and puts the model in eval mode. Then, for each batch, it
    records the `criterion` loss and the fraction of correct top-1 predictions,
    updating a progress bar as it goes.
    """
    metric_meter.reset()
    model.eval()
    n_batches = len(loader)
    for step, (inputs, labels) in enumerate(loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        logits = model(inputs)
        batch_loss = criterion(logits, labels)

        predicted = logits.argmax(dim=1)
        n_correct = predicted.eq(labels.view_as(predicted)).sum().item()

        metric_meter.add({"loss": batch_loss.item(), "acc": n_correct / inputs.shape[0]})
        utils.pbar(step / n_batches, msg=metric_meter.msg())
    utils.pbar(1, msg=metric_meter.msg())


class CIFARRobustness(Dataset):
    """One corruption type at one severity level, read from per-type ``.npy`` files.

    Expects ``<root>/<type>.npy``, with the images for every severity stacked
    along axis 0, and ``<root>/labels.npy``. Severity ``level`` occupies rows
    ``[(level - 1) * N_PER_LEVEL, level * N_PER_LEVEL)`` of both arrays.

    Args:
        root: directory containing the ``.npy`` files.
        type: corruption name; must be one of ``TYPES``.
        level: severity; must be one of ``LEVELS``.
        transform: callable applied to each image.

    Raises:
        ValueError: if `type` or `level` is not supported.
    """

    TYPES = [
        # noise
        "gaussian_noise",
        "shot_noise",
        "impulse_noise",
        # blur
        "defocus_blur",
        "glass_blur",
        "motion_blur",
        "zoom_blur",
        # weather
        "snow",
        "frost",
        "fog",
        "brightness",
        # digital
        "contrast",
        "elastic_transform",
        "pixelate",
        "jpeg_compression",
        # extra
        "gaussian_blur",
        "saturate",
        "spatter",
        "speckle_noise",
    ]
    LEVELS = [1, 2, 3, 4, 5]
    # Rows per severity level in each stacked .npy file.
    N_PER_LEVEL = 10_000

    def __init__(self, root, type, level, transform):
        # Explicit checks rather than `assert`, which is stripped under `python -O`.
        if type not in self.TYPES:
            raise ValueError(f"unknown corruption type {type!r}; expected one of {self.TYPES}")
        if level not in self.LEVELS:
            raise ValueError(f"unknown severity level {level!r}; expected one of {self.LEVELS}")
        start, stop = (level - 1) * self.N_PER_LEVEL, level * self.N_PER_LEVEL
        # Memory-map the stacked files and copy out only the requested slice, so
        # the other severities are never read into RAM.
        self.imgs = np.array(np.load(os.path.join(root, f"{type}.npy"), mmap_mode="r")[start:stop])
        self.labels = np.array(np.load(os.path.join(root, "labels.npy"), mmap_mode="r")[start:stop])
        self.transform = transform

    def __getitem__(self, indx):
        img = self.imgs[indx]
        label = self.labels[indx]
        img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.imgs)


# Per-dataset normalisation, head size and matching clean test set, resolved
# once so an unsupported --dset fails before any checkpoint is loaded and so
# CIFAR-100 models are scored on CIFAR-100 (not CIFAR-10).
if args.dset == "cifar10":
    norm = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    n_cls = 10
    clean_dset_cls = datasets.CIFAR10
elif args.dset == "cifar100":
    norm = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    n_cls = 100
    clean_dset_cls = datasets.CIFAR100
else:
    raise NotImplementedError(f"args.dset = {args.dset} not implemented.")

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        norm,
    ]
)

# The clean test split does not depend on the checkpoint, so build it once.
clean_loader = DataLoader(
    clean_dset_cls(root=args.data_root, train=False, transform=transform, download=True),
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.n_workers,
)

# One line per checkpoint; one accuracy (percent) per (corruption, level) pair,
# in TYPES x LEVELS order.
with open(f"{args.dset}_rob_s_results.txt", "w") as f:
    for data_size in DATA_SIZES:
        for aug in AUGS:
            for imp_iter in ITERS:
                ckpt_path = os.path.join(
                    OUT_DIR, f"sparse_{data_size}_{aug}", f"best_imp_{imp_iter}.ckpt"
                )
                print(f"Evaluating: {ckpt_path}")
                # map_location lets GPU-saved checkpoints load on the chosen device.
                ckpt = torch.load(ckpt_path, map_location=device)
                model = NETWORKS[args.net](n_cls=n_cls, pre_conv="small").to(device)
                if imp_iter:
                    # Pruned checkpoints: apply the mask extracted from the trained
                    # weights before loading them.
                    model.load_state_dict(ckpt["init"])
                    curr_mask = utils.extract_mask(ckpt["model"])
                    utils.mask_prune(model, curr_mask)
                    print("remaining weight = ", utils.check_sparsity(model))
                model.load_state_dict(ckpt["model"])

                # Clean accuracy is printed for reference only.
                eval(clean_loader, model, metric_meter)
                metrics = metric_meter.get()
                print(
                    f"{args.dset}: loss {round(metrics['loss'], 5)}, acc: {round(metrics['acc'], 5)}"
                )

                row = []
                for corruption in CIFARRobustness.TYPES:
                    for level in CIFARRobustness.LEVELS:
                        dset = CIFARRobustness(
                            root=args.rob_data_root, type=corruption, level=level, transform=transform
                        )
                        loader = DataLoader(
                            dset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.n_workers,
                        )
                        eval(loader, model, metric_meter)
                        metrics = metric_meter.get()
                        print(
                            f"{args.dset} {corruption}_{level}: loss {round(metrics['loss'], 5)}, acc: {round(metrics['acc'], 5)}"
                        )
                        # Explicit formatting avoids float noise from round(x, 4) * 100.
                        row.append(f"{metrics['acc'] * 100:.2f}")
                f.write(" ".join(row) + "\n")
                f.flush()

                print("finished evaluating on ckpt")
                print("---------------------------")

## Visualize Layer-wise Sparsity

In [None]:
import argparse
import utils
import random
import numpy as np
import torch
import models
import os

# Architectures selectable through --net; values are constructors from the
# local `models` package.
NETWORKS = {
    arch: getattr(models, arch)
    for arch in ("resnet18", "resnet34", "resnet50", "resnet101", "resnet152")
}

# Checkpoint grid swept below. Unlike the evaluation cells, ITERS omits 0.
DATA_SIZES = ["1", "0.5", "0.2", "0.1", "0.02", "0.01"]
AUGS = ["baseaug", "contrastaug", "randaug", "autoaug"]
ITERS = [4, 9, 14]
# Parent directory of the checkpoint folders ("" means the working directory).
OUT_DIR = ""

parser = utils.add_args(argparse.ArgumentParser())
args = parser.parse_args()


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


_seed_everything(args.seed)

# Placeholder list of hand-picked checkpoints; not read by the loop below.
CKPTS = [
    # best winning tickets
    "",
]

# The classifier head must match the checkpoint, so derive its size from --dset
# instead of assuming CIFAR-10 (which makes CIFAR-100 checkpoints fail to load).
_N_CLASSES = {"cifar10": 10, "cifar100": 100}
if args.dset not in _N_CLASSES:
    raise NotImplementedError(f"args.dset = {args.dset} not implemented.")
n_cls = _N_CLASSES[args.dset]

# Prints one line per checkpoint: for every Conv2d (in module order), the
# percentage of weights that are non-zero, i.e. the remaining density.
for data_size in DATA_SIZES:
    for imp_iter in ITERS:
        for aug in AUGS:
            ckpt_path = os.path.join(
                OUT_DIR, f"sparse_{data_size}_{aug}", f"best_imp_{imp_iter}.ckpt"
            )
            ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))
            model = NETWORKS[args.net](n_cls=n_cls, pre_conv="small")
            if imp_iter:
                # Presumably the pruning mask must be registered via
                # utils.mask_prune before the pruned weights will load — confirm.
                model.load_state_dict(ckpt["init"])
                curr_mask = utils.extract_mask(ckpt["model"])
                utils.mask_prune(model, curr_mask)
            model.load_state_dict(ckpt["model"])

            density = []
            for m in model.modules():
                if isinstance(m, torch.nn.Conv2d):
                    n_total = float(m.weight.nelement())
                    n_zero = float(torch.sum(m.weight == 0))
                    density.append(str(round((1 - n_zero / n_total) * 100, 2)))
            print(" ".join(density))
        print("\n")
    print("\n")