# Benchmarking Hugging Face Accelerate/`xaitk-saliency` Integration

This notebook utilizes PyTorch's benchmarking capability, along with [`submitit`](https://github.com/facebookincubator/submitit), to analyze the integration strategy used for Hugging Face Accelerate and `xaitk-saliency`.

## Table of Contents

* [Environment Setup](#environment-setup)
* [Benchmarking](#benchmarking)
  * [GPU Sweep](#gpu-sweep)
  * [Mask Sweep](#mask-sweep)

## Environment Setup <a name="environment-setup"></a>

In [1]:
# Install this notebook's dependencies with the interpreter running this
# kernel (`sys.executable`), printing a progress line before each install.
import sys

# Upgrade pip itself before installing anything else.
!{sys.executable} -m pip install -qU pip
print("Installing xaitk-jatic...")
# Installs from the directory two levels above this notebook
# (presumably the xaitk-jatic repository root -- confirm).
!{sys.executable} -m pip install -q ../..
print("Installing xaitk-saliency...")
!{sys.executable} -m pip install -q xaitk-saliency
print("Installing smqtk-classifier...")
!{sys.executable} -m pip install -qU smqtk-classifier
print("Installing Hugging Face datasets...")
!{sys.executable} -m pip install -q datasets
print("Installing Hugging Face transformers...")
!{sys.executable} -m pip install -q transformers
print("Installing Hugging Face accelerate...")
!{sys.executable} -m pip install -q accelerate
print("Installing submitit...")
!{sys.executable} -m pip install -q 'submitit'
print("Done!")

Installing xaitk-jatic...
Installing xaitk-saliency...
Installing smqtk-classifier...
Installing Hugging Face datasets...
Installing Hugging Face transformers...
Installing Hugging Face accelerate...
Installing submitit...
Done!


In [2]:
# Note PREDICT_SIZE should be >= BATCH_SIZE, due to the way Accelerate distributes data
# Batch size of the image DataLoader built in `app` (capped there at PREDICT_SIZE).
BATCH_SIZE = 25
# Number of shuffled CIFAR10 test images `app` generates saliency maps for.
PREDICT_SIZE = 100
# Batch size of the DataLoader that `AccelerateClassifier.classify_images`
# builds over the images handed to it by the saliency generator.
MASKED_DATA_BATCH_SIZE = 128

# Forwarded as `min_run_time` to `Timer.blocked_autorange` in both sweeps
# (seconds, per the torch.utils.benchmark documentation).
min_run_time = 30

In [3]:
%matplotlib inline
from matplotlib import pyplot as plt

# Use JPEG format for inline visualizations
%config InlineBackend.figure_format = "jpeg"

import submitit
from submitit.core.core import Executor

import torch.utils.benchmark as benchmark

import numpy as np
import time
import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset, get_worker_info
from torchvision import transforms
from datasets import load_dataset
from transformers import AutoModelForImageClassification
from accelerate import Accelerator, notebook_launcher
from accelerate.utils import set_seed

from scipy.special import softmax
from typing import Iterable, Optional
from smqtk_classifier.interfaces.classify_image import ClassifyImage
from xaitk_saliency.impls.gen_image_classifier_blackbox_sal.slidingwindow import SlidingWindowStack
from xaitk_saliency.impls.gen_image_classifier_blackbox_sal.rise import RISEStack
from xaitk_saliency.interfaces.gen_image_classifier_blackbox_sal import GenerateImageClassifierBlackboxSaliency

# For "artifact tracking" (to compare results)
import pickle

The following is code from the original integration notebook:

In [4]:
def app(
    saliency_generator: GenerateImageClassifierBlackboxSaliency,
    use_accelerate: bool = True,
    display_results: bool = False,
    results_filepath: Optional[str] = None,
):
    """Generate saliency maps for a subset of the CIFAR10 test split.

    Loads a ViT image classifier fine-tuned on CIFAR10, takes the first
    ``PREDICT_SIZE`` images of the shuffled test split, and calls
    ``saliency_generator`` on every image using an ``AccelerateClassifier``
    wrapped around the model.

    :param saliency_generator: Black-box saliency generator invoked once per
        image as ``saliency_generator(image, classifier)``.
    :param use_accelerate: When true, seed with ``set_seed(42)``, create an
        ``Accelerator`` and prepare the model and dataloader with it; the
        per-process results are gathered at the end.
    :param display_results: When true, plot every image next to its per-class
        saliency maps (main process only when using Accelerate).
    :param results_filepath: When not ``None``, pickle the collected saliency
        maps to this path (main process only when using Accelerate).
    """

    class TestDataset(Dataset):
        """Map-style dataset applying a fixed preprocessing pipeline per image."""

        def __init__(self, data):
            self.data = data
            self.transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                    transforms.Resize((224, 224), antialias=True),
                ]
            )

        def __getitem__(self, index):
            return self.transform(self.data[index])

        def __len__(self):
            return len(self.data)

    accelerator = None
    if use_accelerate:
        # For reproducibility
        set_seed(42)

        # Set up the accelerator
        accelerator = Accelerator(even_batches=False)

    # Get the model
    model_name = "aaraki/vit-base-patch16-224-in21k-finetuned-cifar10"
    model = AutoModelForImageClassification.from_pretrained(model_name)

    # Predicting on a subset of the CIFAR10 Test dataset
    ds = load_dataset("cifar10", split="test")
    labels = ds.features["label"].names
    num_classes = len(labels)
    ds_shuffle = ds.shuffle(seed=42)
    images = ds_shuffle[0:PREDICT_SIZE]["img"]
    dataloader = DataLoader(TestDataset(images), batch_size=min(BATCH_SIZE, PREDICT_SIZE))

    if accelerator:
        # Prepare the model and dataloader for use with accelerate
        model, dataloader = accelerator.prepare(model, dataloader)

    image_classifier = AccelerateClassifier(model, labels, accelerator, transform=None)

    # Generate saliency maps
    sal_maps_set = []
    for batch in dataloader:
        b = batch.cpu().data.numpy()
        for img in b:
            # The dataset yields channel-first tensors; the generator is
            # handed a channel-last array.
            sal_maps = saliency_generator(np.moveaxis(img, 0, -1), image_classifier)
            sal_maps_set.append(sal_maps)

    if accelerator:
        # Collect every process's saliency maps into one array.
        accelerator.wait_for_everyone()
        t_sal_maps_set = torch.Tensor(np.array(sal_maps_set)).to(accelerator.device)
        sal_maps_set_gathered = accelerator.gather(t_sal_maps_set)
        sal_maps_set_gathered = sal_maps_set_gathered.data.cpu().numpy()
    else:
        sal_maps_set_gathered = sal_maps_set

    # Plot each image in set with saliency maps
    if display_results and (accelerator is None or accelerator.is_main_process):
        # Two-row grid: the image sits in the first cell, the first
        # `num_cols - 1` classes fill the rest of row one, and the remaining
        # classes fill row two starting under the first class.
        num_cols = np.ceil(num_classes / 2).astype(int) + 1
        for i in range(len(images)):
            plt.figure(figsize=(10, 5))
            plt.subplot(2, num_cols, 1)
            plt.imshow(images[i], cmap="gray")
            plt.xticks(())
            plt.yticks(())

            for c in range(num_cols - 1):
                plt.subplot(2, num_cols, c + 2)
                plt.imshow(sal_maps_set_gathered[i][c], cmap=plt.cm.RdBu, vmin=-1, vmax=1)
                plt.xticks(())
                plt.yticks(())
                plt.xlabel(f"{labels[c]}")
            # Start at `num_cols - 1` so that no class is drawn twice when the
            # number of classes is odd (identical start index when it is even).
            for c in range(num_cols - 1, num_classes):
                plt.subplot(2, num_cols, c + 3)
                plt.imshow(sal_maps_set_gathered[i][c], cmap=plt.cm.RdBu, vmin=-1, vmax=1)
                plt.xticks(())
                plt.yticks(())
                plt.xlabel(f"{labels[c]}")

    # Save results for comparison, for example's sake
    if results_filepath is not None and (accelerator is None or accelerator.is_main_process):
        # Context manager guarantees the file is flushed and closed.
        with open(results_filepath, "wb") as results_file:
            pickle.dump(sal_maps_set_gathered, results_file)

In [5]:
class AccelerateClassifier(ClassifyImage):
    """``ClassifyImage`` implementation wrapping a torch model, optionally run via Accelerate.

    :param model: Model whose forward output exposes a ``logits`` attribute.
    :param labels: Class labels, in the order of the model's logits.
    :param accelerator: Optional ``Accelerator``; when given, images are moved
        to ``accelerator.device`` before being batched.
    :param transform: Optional transform applied to every (channel-first)
        image before classification.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        labels: list,
        accelerator: Optional[Accelerator] = None,
        transform: Optional[transforms.transforms.Compose] = None,
    ):
        self.model = model
        self.accelerator = accelerator
        self.labels = labels
        self.transform = transform

    def get_labels(self):
        """Return the class labels supplied at construction."""
        return self.labels

    class ClassifyImagesDataset(IterableDataset):
        """Iterable dataset that lazily preprocesses images from an iterable.

        :param iterable: Source of images as channel-last arrays.
        :param device: Optional device each item is moved to as a tensor.
        :param transform: Optional transform applied after the channel axis
            has been moved to the front.
        """

        def __init__(
            self, iterable: Iterable[np.ndarray], device=None, transform: Optional[transforms.transforms.Compose] = None
        ):
            self._iterable = iterable
            self._device = device
            self._transform = transform

        def __iter__(self):
            tnsfm = self._transform
            device = self._device

            for image in self._iterable:
                # Channel-last -> channel-first
                image = np.moveaxis(image, -1, 0)
                if tnsfm:
                    item = tnsfm(image)
                else:
                    item = image
                if device:
                    item = torch.Tensor(item).to(device)
                yield item

    def classify_images(self, image_iter):
        """Classify every image in ``image_iter``.

        :param image_iter: Iterable of channel-last image arrays.
        :return: One ``{label: probability}`` dict per input image, in input
            order, with probabilities obtained by a softmax over the logits.
        """
        device = self.accelerator.device if self.accelerator else None
        dataloader = DataLoader(
            # Forward the configured transform; it was previously stored on the
            # instance but never applied.
            self.ClassifyImagesDataset(image_iter, device, transform=self.transform),
            batch_size=MASKED_DATA_BATCH_SIZE,
            shuffle=False,
        )

        self.model.eval()
        results = []
        for batch in dataloader:
            with torch.no_grad():
                preds = softmax(self.model(batch).logits.data.cpu().numpy(), axis=1)
            results.extend([dict(zip(self.labels, pred)) for pred in preds])

        return results

    # Required for implementation
    def get_config(self):
        """Return an empty configuration dictionary."""
        return {}

## Benchmarking <a name="benchmarking"></a>

We'll benchmark against (1) a varying number of GPUs and (2) a varying number of masks to see how this affects computation time.

We'll first define a utility function to more easily submit jobs via submitit:

In [6]:
def run_app(app, *args, **kwargs):
    """Submit ``notebook_launcher`` through the global ``executor`` and wait on it.

    ``app`` itself is not used by the body: only ``*args`` and ``**kwargs``
    are forwarded to ``notebook_launcher``. ``executor`` is looked up as a
    module-level name at call time, so it must be defined before calling.
    """
    submitted_job = executor.submit(notebook_launcher, *args, **kwargs)
    # Per the submitit docs, `results()` waits for the job's tasks to finish.
    submitted_job.results()

### GPU Sweep <a name="gpu-sweep"></a>

In [7]:
# Benchmark `app` with a sliding-window saliency generator on 1, 2 and 4 GPUs.
gen_sliding_window = SlidingWindowStack(window_size=(14, 14), stride=(7, 7), threads=4)

# One torch benchmark measurement per GPU count, in sweep order.
gpu_benchmark_results = []

gpus = [1, 2, 4]

for g in gpus:
    label = f"GPU Sweep ({PREDICT_SIZE} sample images)"
    sub_label = f"{g} GPU"

    # `executor` is a module-level name that `run_app` reads when the timed
    # statement executes, so it is rebound here for each GPU count.
    executor = submitit.AutoExecutor(folder="submitit_logs", cluster="slurm")
    executor.update_parameters(gpus_per_node=g, slurm_partition="community", slurm_account="xai", timeout_min=180)
    # Positional arguments for `notebook_launcher`: the function to launch and
    # the tuple of arguments it is called with.
    args = (
        app,
        (
            gen_sliding_window,
            True,  # use_accelerate
            False,  # display_results
            None,  # results_filepath
        ),
    )
    # Launch as many processes as GPUs were requested.
    kwargs = {"num_processes": g}

    print(f"Starting GPU sweep test: {g} GPU")
    gpu_benchmark_results.append(
        benchmark.Timer(
            stmt="run_app(app, *args, **kwargs)",
            setup="from __main__ import run_app",
            globals={"app": app, "args": args, "kwargs": kwargs},
            label=label,
            sub_label=sub_label,
            description="time",
        ).blocked_autorange(min_run_time=min_run_time)
    )

# Print all measurements as one comparison table.
compare = benchmark.Compare(gpu_benchmark_results)
compare.print()

Starting GPU sweep test: 1 GPU
Starting GPU sweep test: 2 GPU
Starting GPU sweep test: 4 GPU
[ GPU Sweep (100 sample images) ]
             |   time
1 threads: ----------
      1 GPU  |  414.8
      2 GPU  |  219.5
      4 GPU  |  127.4

Times are in seconds (s).



### Mask Sweep <a name="mask-sweep"></a>

In [8]:
# Benchmark `app` with RISE saliency generators built with a varying `n`
# (labelled below as the number of masks), always on 4 GPUs.
# One torch benchmark measurement per `n`, in sweep order.
mask_benchmark_results = []

n_masks = [50, 100, 200, 400]
# NOTE: this rebinds `gpus`, which the GPU sweep cell uses for a list of
# GPU counts; here it is a single count.
gpus = 4

for n in n_masks:
    label = f"Mask Sweep ({PREDICT_SIZE} sample images)"
    sub_label = f"{n} Masks"

    # `executor` is a module-level name that `run_app` reads when the timed
    # statement executes.
    executor = submitit.AutoExecutor(folder="submitit_logs", cluster="slurm")
    executor.update_parameters(gpus_per_node=gpus, slurm_partition="community", slurm_account="xai", timeout_min=180)
    # Launch as many processes as GPUs were requested.
    kwargs = {"num_processes": gpus}

    # Fixed seed so every run with the same `n` builds the generator identically.
    gen_rise_stack = RISEStack(n=n, s=8, p1=0.5, seed=0, threads=4)

    print(f"Starting number masks test: {n} masks")
    # Positional arguments for `notebook_launcher`: the function to launch and
    # the tuple of arguments it is called with.
    args = (
        app,
        (
            gen_rise_stack,
            True,  # use_accelerate
            False,  # display_results
            None,  # results_filepath
        ),
    )
    mask_benchmark_results.append(
        benchmark.Timer(
            stmt="run_app(app, *args, **kwargs)",
            setup="from __main__ import run_app",
            globals={"app": app, "args": args, "kwargs": kwargs},
            label=label,
            sub_label=sub_label,
            description="time",
        ).blocked_autorange(min_run_time=min_run_time)
    )

# Print all measurements as one comparison table.
compare = benchmark.Compare(mask_benchmark_results)
compare.print()

Starting number masks test: 50 masks
Starting number masks test: 100 masks
Starting number masks test: 200 masks
Starting number masks test: 400 masks
[ Mask Sweep (100 sample images) ]
                 |  time
1 threads: -------------
      50 Masks   |  14.2
      100 Masks  |  20.2
      200 Masks  |  36.3
      400 Masks  |  60.3

Times are in seconds (s).

