# SIGE Benchmark on Progressive Distillation

## Preparations
### Installation (This may take several minutes)

In [None]:
!pip install torch torchvision
!pip install sige
!pip install torchprofile gdown tqdm ipyplot pyyaml easydict

### Clone the Repository

In [None]:
!git clone https://github.com/lmxyy/sige.git

In [None]:
import os

# Enter the diffusion sub-project of the cloned repository.
# The guard keeps the cell re-runnable: once we are inside sige/diffusion the
# relative path no longer exists and an unconditional chdir would raise
# FileNotFoundError.
if not os.getcwd().endswith(os.path.join("sige", "diffusion")):
    os.chdir("sige/diffusion")
print(os.getcwd())


### Create Configurations

In [None]:
from argparse import Namespace

import torch
import yaml
from easydict import EasyDict

from utils import device_synchronize, get_device, set_seed

# Load the `church_pd128` configs (original and SIGE variants) and wrap each
# parsed YAML document in an EasyDict so nested keys can be read as attributes.
with open("configs/church_pd128-original.yml", "r") as f:
    config_vanilla = yaml.safe_load(f)
config_vanilla = EasyDict(config_vanilla)
with open("configs/church_pd128-sige.yml", "r") as f:
    config_sige = yaml.safe_load(f)
config_sige = EasyDict(config_sige)

# Pick the compute device with the repo helper and record it on both configs.
device = get_device()
print("Device:", device)
config_vanilla.device = device
config_sige.device = device

# Build a dummy args: an empty Namespace that is handed to the model and
# sampler constructors later in the notebook.
args = Namespace()
set_seed(0)  # for reproducibility, feel free to change the seed


### Create Models

Define some helper functions for creating the models.

In [None]:
from typing import Optional

from torch import nn

from download_helper import get_ckpt_path
from models.ema import EMAHelper
from models.pd_arch.sige_unet import SIGEUNet
from models.pd_arch.unet import UNet


def build_model(args, config):
    """Instantiate the network named by ``config.model.network`` on the global ``device``.

    Args:
        args: Passed through unchanged to the model constructor.
        config: Configuration object. Reads ``config.model.network``,
            ``config.model.ema`` and, when EMA is enabled, ``config.model.ema_rate``.

    Returns:
        A ``(model, ema_helper)`` tuple. ``ema_helper`` is an ``EMAHelper``
        registered on the model when ``config.model.ema`` is truthy, otherwise ``None``.

    Raises:
        NotImplementedError: If the network name is neither ``"pd.unet"`` nor
            ``"pd.sige_unet"``.
    """
    network: str = config.model.network
    # Reject unknown names up front, then choose between the two supported classes.
    if network not in ("pd.unet", "pd.sige_unet"):
        raise NotImplementedError("Unknown network [%s]!!!" % network)
    model_cls = UNet if network == "pd.unet" else SIGEUNet
    model = model_cls(args, config).to(device)

    if not config.model.ema:
        return model, None

    ema_helper = EMAHelper(mu=config.model.ema_rate)
    ema_helper.register(model)
    return model, ema_helper


def restore_checkpoint(model: nn.Module, ema_helper: Optional[EMAHelper], path: str):
    """Load model (and optionally EMA) state from the checkpoint stored at ``path``.

    The checkpoint is read with ``torch.load`` and must contain a ``"model"``
    entry; an ``"ema"`` entry is optional.

    Args:
        model: Network to load the weights into. An ``nn.DataParallel`` wrapper
            is unwrapped first.
        ema_helper: EMA helper to restore, or ``None`` to skip EMA handling.
        path: Filesystem path of the checkpoint.

    Returns:
        ``(model, ema_helper)``. ``model`` is the unwrapped module when an
        ``nn.DataParallel`` wrapper was passed in.
    """
    if isinstance(model, nn.DataParallel):
        model = model.module

    # Remap the stored tensors onto the device the model lives on. Without
    # map_location, torch.load tries to restore every tensor to the device it
    # was saved from, which fails on machines that lack that device
    # (e.g. a CUDA-saved checkpoint on a CPU-only host).
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else "cpu"
    states = torch.load(path, map_location=map_location)

    model.load_state_dict(states["model"])
    if ema_helper is not None:
        if "ema" in states:
            ema_helper.load_state_dict(states["ema"])
        else:
            # No EMA state in the checkpoint: register the helper on the
            # just-loaded model instead.
            ema_helper.register(model)
    return model, ema_helper


Build the vanilla model. **It may take some time to download the model weights, and the download sometimes gets stuck. If that happens, switch the download tool to `torch_hub` by passing `tool="torch_hub"` and rerun the cell, or download the weights manually.**

In [None]:
# Vanilla Progressive Distillation Model
vanilla_model, vanilla_ema_helper = build_model(args, config_vanilla)
pretrained_path = get_ckpt_path(config_vanilla, tool="gdown")
restore_checkpoint(vanilla_model, vanilla_ema_helper, pretrained_path)
vanilla_model.eval()
print("Vanilla model is built successfully!")


Build the SIGE model. **It may take some time to download the model weights, and the download sometimes gets stuck. If that happens, switch the download tool to `torch_hub` by passing `tool="torch_hub"` and rerun the cell, or download the weights manually.**

In [None]:
# SIGE Progressive Distillation Model
sige_model, sige_ema_helper = build_model(args, config_sige)
pretrained_path = get_ckpt_path(config_sige, tool="gdown")
restore_checkpoint(sige_model, sige_ema_helper, pretrained_path)
sige_model.eval()
print("SIGE model is built successfully!")


### Prepare Data
We have prepared two pairs of user edits in [`./assets`](./assets). Here, we use [`./assets/original.png`](./assets/original.png) as the original image and [`./assets/edited.png`](./assets/edited.png) as the edited image. You are free to use any other pairs of data either in our benchmark dataset (see [README.md](./README.md)) or created yourself.

In [None]:
from PIL import Image
from torchvision import transforms

# Paths of the example edit pair shipped with the repository.
original_image_path = "./assets/original.png"
edited_image_path = "./assets/edited.png"

# NOTE(review): the images are used in whatever mode the PNG files are stored in
# (there is no explicit .convert("RGB")) — confirm the assets are 3-channel.
original_image = Image.open(original_image_path)
edited_image = Image.open(edited_image_path)

# The same resized images are fed to both models, so both configs must agree
# on the image size.
assert config_vanilla.data.image_size == config_sige.data.image_size
image_size = config_vanilla.data.image_size
# NOTE(review): Resize is given a single value; presumably the assets are
# square so the result is image_size x image_size — confirm.
resize = transforms.Resize(image_size)
original_image = resize(original_image)
edited_image = resize(edited_image)


Display the images.

In [None]:
import ipyplot
import numpy as np

ipyplot.plot_images((np.array(original_image), np.array(edited_image)), ("Original", "Edited"))


Convert the images to tensors.

In [None]:
# Convert the PIL images to tensors, add a leading batch dimension and move
# them to the selected device.
toTensor = transforms.ToTensor()
original_image = toTensor(original_image).unsqueeze(0).to(device)
edited_image = toTensor(edited_image).unsqueeze(0).to(device)

# Rescale the tensors to [-1, 1]
original_image = 2 * original_image - 1
edited_image = 2 * edited_image - 1

# Draw one noise tensor and reuse it for both images, so the original and the
# edited image are paired with identical noise.
e = torch.randn_like(original_image)
# Batch layout used by the rest of the notebook: index 0 = original, index 1 = edited.
x0s = torch.cat([original_image, edited_image], dim=0)
es = torch.cat([e, e], dim=0)


Compute the difference masks. `sige` has some helper functions for this.

In [None]:
from sige.utils import compute_difference_mask, dilate_mask, downsample_mask

# One difference mask is shared by both models, so the mask-related sampling
# settings must be identical in the two configs.
assert config_vanilla.sampling.eps == config_sige.sampling.eps
assert config_vanilla.sampling.mask_dilate_radius == config_sige.sampling.mask_dilate_radius

eps = config_vanilla.sampling.eps
mask_dilate_radius = config_vanilla.sampling.mask_dilate_radius

# Mask of the regions where the edited image differs from the original,
# then dilated by the configured radius.
difference_mask = compute_difference_mask(original_image, edited_image, eps=eps)
difference_mask = dilate_mask(difference_mask, mask_dilate_radius)

# Downsample the mask to different resolutions
# The second argument is image_size divided by 2 ** (number of ch_mult entries - 1);
# presumably the smallest feature-map resolution of the network — TODO confirm.
# `masks` is a mapping (it is iterated with .items() in the visualization cell).
masks = downsample_mask(difference_mask, image_size // (2 ** (len(config_vanilla.model.ch_mult) - 1)))


Visualize the masks.

In [None]:
def mask_to_image(mask: torch.Tensor):
    """Render a mask tensor as a PIL image of ``image_size`` x ``image_size`` pixels.

    The tensor is moved to the CPU and converted to a NumPy array before being
    handed to ``Image.fromarray``. ``image_size`` is read from the notebook's
    global namespace.
    """
    pil_mask = Image.fromarray(mask.cpu().numpy())
    target_size = (image_size,) * 2
    return pil_mask.resize(target_size)


# Show the full-resolution difference mask. The percentage in the caption is
# mask.sum() / mask.numel(), i.e. the share of mask elements that are set.
mask_image = mask_to_image(difference_mask)
message = "Sparsity: %.2f%%" % (100 * difference_mask.sum() / difference_mask.numel())
ipyplot.plot_images((np.array(mask_image),), ("Difference Mask",), (message,), img_width=image_size)

# Show every downsampled mask, each upscaled to image_size by mask_to_image so
# the plots are the same size.
print("Downsampled Masks")
arrays, labels, messages = [], [], []
# k is the two-element resolution key of the mask, v is the mask tensor itself.
for i, (k, v) in enumerate(masks.items()):
    image = mask_to_image(v)
    arrays.append(np.array(image))
    labels.append("Resolution: %dx%d" % (k[0], k[1]))
    messages.append("Sparsity: %.2f%%" % (100 * v.sum() / v.numel()))
ipyplot.plot_images(arrays, labels, messages, img_width=image_size)


## Test Models
### Quality Results

Define the Progressive Distillation sampler.

In [None]:
from samplers.pd_sampler import PDSampler as Sampler

# A single sampler, built from the vanilla config, is shared by the original model and the SIGE model.
sampler = Sampler(args, config_vanilla)


Get some diffusion variables.

In [None]:
def get_sampling_sequence(config, noise_level=None):
    """Build the sequence of diffusion timesteps visited during sampling.

    Args:
        config: Configuration object. Reads ``config.sampling.skip_type``,
            ``config.sampling.sample_steps`` and, when ``noise_level`` is
            ``None``, ``config.total_steps``.
        noise_level: Upper timestep bound of the sequence. Defaults to
            ``config.total_steps``.

    Returns:
        For ``"uniform"``: ``range(0, noise_level, noise_level // sample_steps)``.
        For ``"quad"``: a list of ints made of ``sample_steps - 1`` quadratically
        spaced values between 0 and ``0.8 * noise_level`` (truncated to int),
        followed by ``noise_level`` itself.

    Raises:
        ValueError: If the uniform stride would be zero, i.e. ``sample_steps``
            is larger than ``noise_level``.
        NotImplementedError: If ``skip_type`` is neither ``"uniform"`` nor ``"quad"``.
    """
    if noise_level is None:
        noise_level = config.total_steps

    skip_type = config.sampling.skip_type
    timesteps = config.sampling.sample_steps

    if skip_type == "uniform":
        skip = noise_level // timesteps
        if skip == 0:
            # range() would raise "arg 3 must not be zero"; say what is actually wrong.
            raise ValueError(
                "sample_steps (%s) must not exceed noise_level (%s) for the 'uniform' skip type"
                % (timesteps, noise_level)
            )
        seq = range(0, noise_level, skip)
    elif skip_type == "quad":
        seq = np.linspace(0, np.sqrt(noise_level * 0.8), timesteps - 1) ** 2
        seq = [int(s) for s in list(seq)]
        # The sequence always ends exactly at the requested noise level.
        seq.append(noise_level)
    else:
        raise NotImplementedError("Unknown skip type [%s]!!!" % skip_type)
    return seq


# The sampling sequence should be the same for both the original model and SIGE model.
seq = get_sampling_sequence(config_vanilla, noise_level=config_vanilla.sampling.noise_level)
# One timestep per batch element, all set to the last entry of the sequence.
ts = torch.full((x0s.size(0),), seq[-1], device=x0s.device, dtype=torch.long)  # The starting timestep
xts = sampler.get_xt_from_x0(x0s, ts, es)  # Perturb the image with the corresponding noise level
# Index 0 of the batch is the original image and its noise.
gt_x0, gt_e = x0s[:1], es[:1]  # Used for the mask trick in SDEdit to keep the unedited regions unchanged.


Start denoising.
* Denoising with the vanilla model.

In [None]:
with torch.no_grad():
    vanilla_generate_x0s = sampler.denoising_steps(
        xts[1:], vanilla_model, seq, gt_x0=gt_x0, gt_e=gt_e, difference_mask=difference_mask
    )
    vanilla_result = (vanilla_generate_x0s[0] + 1) / 2


* Denoising with the SIGE model. Currently, `sige` only supports caching the model for a single step. Therefore, we run a simple simulation to get the results: for every denoising step, we first denoise the noisy original image (`xts[:1]`) to cache the activations, and then run the actual sparse inference on `xts[1:]`.

In [None]:
with torch.no_grad():
    # Need a pre-run to determine the data shape
    sige_model.set_mode("full")
    sige_model(original_image, torch.zeros(1, device=device, dtype=torch.float32))

    # Set the difference masks for the sparse inference.
    # It will automatically reduce the masks to active indices.
    sige_model.set_masks(masks)

    # Unlike the vanilla run, the whole batch is passed in: index 0 is the noisy
    # original image and index 1 is the noisy edited image.
    sige_generate_x0s = sampler.denoising_steps(
        xts, sige_model, seq, gt_x0=gt_x0, gt_e=gt_e, difference_mask=difference_mask
    )
    # Keep the edited result (index 1) and map it from [-1, 1] back to [0, 1].
    sige_result = (sige_generate_x0s[1] + 1) / 2


Visualize these two images.

In [None]:
vanilla_result = vanilla_result.clip(0, 1).permute(1, 2, 0).cpu().numpy()
vanilla_result = (vanilla_result * 255).astype(np.uint8)
vanilla_image = Image.fromarray(vanilla_result)

sige_result = sige_result.clip(0, 1).permute(1, 2, 0).cpu().numpy()
sige_result = (sige_result * 255).astype(np.uint8)
sige_image = Image.fromarray(sige_result)


In [None]:
import ipyplot
import numpy as np

ipyplot.plot_images((np.array(vanilla_image), np.array(sige_image)), ("Vanilla", "SIGE"), img_width=image_size)


These two images should be very similar.

### Efficiency Results
#### $128\times128$ Model
First, let's profile the MACs of these two models.

In [None]:
from torchprofile import profile_macs

# Create some dummy inputs: the first image of the batch and its starting timestep.
dummy_inputs = (x0s[:1], ts[:1])

with torch.no_grad():
    vanilla_macs = profile_macs(vanilla_model, dummy_inputs)

    # For the SIGE model, we need to first run it in the `full` mode to cache the results.
    sige_model.set_mode("full")
    sige_model(*dummy_inputs)
    # We also need to set the difference mask if not set.
    sige_model.set_masks(masks)

    # Switch to the `profile` mode to profile MACs. This mode is only for MACs profiling.
    sige_model.set_mode("profile")
    sige_macs = profile_macs(sige_model, dummy_inputs)


In [None]:
print("Vanilla MACs: %.3fG" % (vanilla_macs / 1e9))
print("SIGE MACs: %.3fG" % (sige_macs / 1e9))


SIGE model has a $\sim 1.5\times$ MACs reduction. This is less prominent than the DDIM result as Progressive Distillation only supports $128\times128$ images, and it is hard for SIGE to accelerate convolution with small resolution. Now let's measure the latency.

In [None]:
import time

from tqdm import tqdm

# Change these numbers if they are too large for you.
warmup_times = 100
test_times = 100


def measure_latency(model: nn.Module):
    """Time forward passes of ``model`` on the global ``dummy_inputs``.

    Runs ``warmup_times`` untimed iterations followed by ``test_times`` timed
    iterations, calling ``device_synchronize(device)`` after every forward pass.
    ``warmup_times``, ``test_times``, ``dummy_inputs`` and ``device`` are read
    from the notebook's global namespace.

    Args:
        model: The module to benchmark. It is called as ``model(*dummy_inputs)``.

    Returns:
        ``(total_seconds, seconds_per_iteration)`` for the timed loop. The
        interval includes the progress-bar overhead of the timed loop.
    """
    for _ in tqdm(range(warmup_times)):
        model(*dummy_inputs)
        device_synchronize(device)
    # perf_counter() is the monotonic, high-resolution clock intended for
    # measuring intervals; time.time() is a wall clock that can jump when the
    # system time is adjusted.
    start_time = time.perf_counter()
    for _ in tqdm(range(test_times)):
        model(*dummy_inputs)
        device_synchronize(device)
    cost_time = time.perf_counter() - start_time
    return cost_time, cost_time / test_times


with torch.no_grad():
    vanilla_cost, vanilla_avg = measure_latency(vanilla_model)

    # As we have already cached some dummy results for the SIGE model, no need to rerun it in the `full` mode.
    # Switch to the `sparse` mode to test the latency.
    sige_model.set_mode("sparse")
    sige_cost, sige_avg = measure_latency(sige_model)


In [None]:
print("Vanilla: Cost %.2fs Avg %.2fms" % (vanilla_cost, vanilla_avg * 1000))
print("SIGE: Cost %.2fs Avg %.2fms" % (sige_cost, sige_avg * 1000))


#### $256\times256$ Model
We also provide configurations with some additional layers that adapt Progressive Distillation to $256\times256$ resolution. Let's see how SIGE performs on the $256\times256$ model. Let's first build the models.

Get configurations.

In [None]:
# Rebind config_vanilla / config_sige to the `church_pd256` configs; from here
# on these names no longer refer to the pd128 configs loaded earlier.
# NOTE(review): unlike the pd128 cell, `device` is not attached to these
# configs — confirm nothing built from them reads `config.device`.
with open("configs/church_pd256-original.yml", "r") as f:
    config_vanilla = yaml.safe_load(f)
config_vanilla = EasyDict(config_vanilla)
with open("configs/church_pd256-sige.yml", "r") as f:
    config_sige = yaml.safe_load(f)
config_sige = EasyDict(config_sige)


Build models. As we do not have the corresponding weights, we skip loading the weights. This will not hurt the efficiency results.
* Vanilla model.

In [None]:
# Vanilla Progressive Distillation Model
vanilla_model, vanilla_ema_helper = build_model(args, config_vanilla)
vanilla_model.eval()
print("Vanilla model is built successfully!")


* SIGE model.

In [None]:
# SIGE Progressive Distillation Model
sige_model, sige_ema_helper = build_model(args, config_sige)
sige_model.eval()
print("SIGE model is built successfully!")


Prepare data.

In [None]:
original_image = Image.open(original_image_path)
edited_image = Image.open(edited_image_path)

assert config_vanilla.data.image_size == config_sige.data.image_size
image_size = config_vanilla.data.image_size
resize = transforms.Resize(image_size)
original_image = resize(original_image)
edited_image = resize(edited_image)

ipyplot.plot_images((np.array(original_image), np.array(edited_image)), ("Original", "Edited"), img_width=image_size)

toTensor = transforms.ToTensor()
original_image = toTensor(original_image).unsqueeze(0).to(device)
edited_image = toTensor(edited_image).unsqueeze(0).to(device)

# Rescale the tensors to [-1, 1]
original_image = 2 * original_image - 1
edited_image = 2 * edited_image - 1


Compute the difference masks.

In [None]:
assert config_vanilla.sampling.eps == config_sige.sampling.eps
assert config_vanilla.sampling.mask_dilate_radius == config_sige.sampling.mask_dilate_radius

eps = config_vanilla.sampling.eps
mask_dilate_radius = config_vanilla.sampling.mask_dilate_radius

difference_mask = compute_difference_mask(original_image, edited_image, eps=eps)
difference_mask = dilate_mask(difference_mask, mask_dilate_radius)

# Downsample the mask to different resolutions
masks = downsample_mask(difference_mask, image_size // (2 ** (len(config_vanilla.model.ch_mult) - 1)))

mask_image = mask_to_image(difference_mask)
message = "Sparsity: %.2f%%" % (100 * difference_mask.sum() / difference_mask.numel())
ipyplot.plot_images((np.array(mask_image),), ("Difference Mask",), (message,), img_width=image_size)

print("Downsampled Masks")
arrays, labels, messages = [], [], []
for i, (k, v) in enumerate(masks.items()):
    image = mask_to_image(v)
    arrays.append(np.array(image))
    labels.append("Resolution: %dx%d" % (k[0], k[1]))
    messages.append("Sparsity: %.2f%%" % (100 * v.sum() / v.numel()))
ipyplot.plot_images(arrays, labels, messages, img_width=image_size)


Profile MACs.

In [None]:
# Dummy inputs for profiling: the resized original image and a single zero timestep.
dummy_inputs = (original_image, torch.zeros(1, device=device, dtype=torch.float32))

with torch.no_grad():
    vanilla_macs = profile_macs(vanilla_model, dummy_inputs)

    # For the SIGE model, we need to first run it in the `full` mode to cache the results.
    sige_model.set_mode("full")
    sige_model(*dummy_inputs)
    # Register the difference masks computed for this resolution.
    sige_model.set_masks(masks)

    # Switch to the `profile` mode to profile MACs. This mode is only for MACs profiling.
    sige_model.set_mode("profile")
    sige_macs = profile_macs(sige_model, dummy_inputs)


In [None]:
print("Vanilla MACs: %.3fG" % (vanilla_macs / 1e9))
print("SIGE MACs: %.3fG" % (sige_macs / 1e9))


The SIGE model has a $\sim 2.2\times$ MACs reduction now, much more prominent than with the $128\times128$ model.

Measure latency.

In [None]:
with torch.no_grad():
    vanilla_cost, vanilla_avg = measure_latency(vanilla_model)

    # As we have already cached some dummy results for the SIGE model, no need to rerun it in the `full` mode.
    # Switch to the `sparse` mode to test the latency.
    sige_model.set_mode("sparse")
    sige_cost, sige_avg = measure_latency(sige_model)


In [None]:
print("Vanilla: Cost %.2fs Avg %.2fms" % (vanilla_cost, vanilla_avg * 1000))
print("SIGE: Cost %.2fs Avg %.2fms" % (sige_cost, sige_avg * 1000))
