# SIGE Benchmark on GauGAN

## Preparations
### Installation (This may take several minutes)

In [None]:
!pip install torch torchvision
!pip install sige
!pip install torchprofile gdown tqdm ipyplot pyyaml easydict

### Clone the Repository

In [None]:
!git clone https://github.com/lmxyy/sige.git

In [None]:
import os

# Enter the GauGAN example directory of the cloned repository. The path is
# relative, so re-running this cell used to fail once we were already inside
# it; skip the change in that case.
gaugan_dir = os.path.join("sige", "gaugan")
if not os.getcwd().endswith(gaugan_dir):
    os.chdir(gaugan_dir)
print(os.getcwd())


### Get arguments

In [None]:
import argparse


def get_args(args_str: str):
    """Parse a command-line style string into the options used by this notebook.

    Args:
        args_str: Whitespace-separated options, e.g. ``"--netG spade"``. An
            empty string (or extra spaces between options) is accepted and
            yields the defaults for everything not specified.

    Returns:
        The parsed ``argparse.Namespace`` with one extra attribute,
        ``semantic_nc``, equal to ``input_nc`` plus one unless
        ``--no_instance`` was given.
    """
    parser = argparse.ArgumentParser()

    # Model related
    parser.add_argument("--netG", type=str, default="spade")
    parser.add_argument("--ngf", type=int, default=64)
    parser.add_argument("--input_nc", type=int, default=35)
    parser.add_argument("--output_nc", type=int, default=3)
    parser.add_argument(
        "--separable_conv_norm",
        type=str,
        default="instance",
        choices=("none", "instance", "batch"),
        help="whether to use instance norm for the separable convolutions",
    )
    parser.add_argument(
        "--norm_G", type=str, default="spadesyncbatch3x3", help="instance normalization or batch normalization"
    )
    parser.add_argument(
        "--num_upsampling_layers",
        choices=("normal", "more", "most"),
        default="more",
        help="If 'more', adds upsampling layer between the two middle resnet blocks. "
        "If 'most', also add one more upsampling + resnet layer at the end of the generator",
    )
    parser.add_argument(
        "--norm",
        type=str,
        default="instance",
        help="instance normalization or batch normalization [instance | batch | none]",
    )
    parser.add_argument(
        "--config_str", type=str, default=None, help="the configuration string for a specific subnet in the supernet"
    )

    # Data related
    parser.add_argument("--crop_size", type=int, default=512)
    parser.add_argument("--no_instance", action="store_true")
    parser.add_argument("--aspect_ratio", type=int, default=2)

    # SIGE related
    parser.add_argument("--main_block_size", type=int, default=6)
    parser.add_argument("--shortcut_block_size", type=int, default=4)
    parser.add_argument("--num_sparse_layers", type=int, default=5)
    parser.add_argument("--mask_dilate_radius", type=int, default=1)
    parser.add_argument("--downsample_dilate_radius", type=int, default=2)

    # Split on any run of whitespace: ``split(" ")`` produced empty tokens for
    # an empty string or repeated spaces, which argparse rejects as
    # unrecognized positional arguments.
    args = parser.parse_args(args_str.split())
    # The instance edge map, when used, contributes one extra input channel.
    args.semantic_nc = args.input_nc + (0 if args.no_instance else 1)
    return args


In [None]:
# One option string per benchmarked generator, parsed in a single pass.
vanilla_args, sige_vanilla_args, gc_args, sige_gc_args = map(
    get_args,
    (
        # Vanilla (original) GauGAN
        "--netG spade",
        # Vanilla GauGAN with SIGE
        "--netG sige_fused_spade",
        # GAN Compression
        "--netG sub_mobile_spade --config_str 32_32_32_48_32_24_24_32",
        # GAN Compression with SIGE
        "--netG sige_fused_sub_mobile_spade --config_str 32_32_32_48_32_24_24_32 --num_sparse_layers 4",
    ),
)


### Create Models

Set device.

In [None]:
import torch
from utils import device_synchronize, get_device

# Pick the compute device through the project's `get_device` helper (its
# selection logic lives in `utils` and is not shown here). The resulting
# `device` is reused by later cells when creating tensors and moving models.
device = get_device()
print("Device:", device)

Define some helper functions for creating the models.

In [None]:
from torch import nn

from utils import decode_config
from download_helper import get_ckpt_path


def get_model(args, tool="gdown") -> nn.Module:
    """Build the generator selected by ``args.netG`` with its pretrained weights.

    Args:
        args: Parsed options (see ``get_args``). ``args.netG`` selects the
            generator class; ``args.config_str`` is decoded for the
            ``sub_mobile_spade`` variants.
        tool: Download backend used to fetch the checkpoint, forwarded to
            ``get_ckpt_path``.

    Returns:
        The model, moved to the notebook's ``device`` and set to eval mode.

    Raises:
        NotImplementedError: If ``args.netG`` names an unknown generator.
    """
    netG = args.netG
    config = None
    # Generator classes are imported lazily so only the selected one is loaded.
    if netG == "spade":
        from models.spade_generators.spade_generator import SPADEGenerator as Model
    elif netG == "fused_spade":
        from models.spade_generators.fused_spade_generator import FusedSPADEGenerator as Model
    elif netG == "sige_fused_spade":
        from models.spade_generators.sige_fused_spade_generator import SIGEFusedSPADEGenerator as Model
    elif netG == "sub_mobile_spade":
        from models.sub_mobile_spade_generators.sub_mobile_spade_generator import SubMobileSPADEGenerator as Model

        config = decode_config(args.config_str)
    elif netG == "sige_fused_sub_mobile_spade":
        from models.sub_mobile_spade_generators.sige_fused_sub_mobile_spade_generator import SIGEFusedSubMobileSPADEGenerator as Model

        config = decode_config(args.config_str)
    else:
        raise NotImplementedError("Unknown netG: [%s]!!!" % netG)

    model = Model(args, config=config)
    # Forward the caller's download tool. It was previously accepted but
    # ignored, so choosing another tool had no effect.
    # NOTE(review): assumes `get_ckpt_path` accepts a `tool` keyword - confirm
    # against download_helper.
    pretrained_path = get_ckpt_path(args, tool=tool)
    model = load_network(model, pretrained_path)
    model = model.to(device)
    model.eval()

    return model


def load_network(net: nn.Module, path: str, verbose: bool = False) -> nn.Module:
    """Load the checkpoint at ``path`` into ``net`` and return ``net``.

    Only the keys present in ``net.state_dict()`` are read from the
    checkpoint. When a checkpoint tensor has a different shape from the
    model's tensor, it must be 1-D and belong to a ``param_free_norm`` entry;
    it is then cut down to the model's length.

    Args:
        net: The module to populate (modified in place).
        path: Path to a checkpoint saved as a state dict.
        verbose: If True, print each entry that gets truncated.

    Returns:
        ``net`` with the loaded weights.

    Raises:
        KeyError: If the checkpoint lacks a key that ``net`` expects.
        AssertionError: If a shape mismatch is not a 1-D ``param_free_norm``
            entry.
    """
    target_state = net.state_dict()
    # Deserialize onto the CPU so checkpoints saved on a GPU also load on
    # machines without one; `load_state_dict` copies into `net`'s own tensors.
    checkpoint = torch.load(path, map_location="cpu")
    state_dict = {}
    for key, target in target_state.items():
        loaded = checkpoint[key]
        if target.shape != loaded.shape:
            assert target.dim() == loaded.dim() == 1
            assert "param_free_norm" in key
            if verbose:
                print("Truncating [%s]: %d -> %d" % (key, loaded.shape[0], target.shape[0]))
            loaded = loaded[: target.shape[0]]
        state_dict[key] = loaded
    net.load_state_dict(state_dict)
    return net


Build the models. **Downloading the model weights may take some time, and the download can occasionally get stuck. If that happens, switch the download tool to `torch_hub` with `tool="torch_hub"` and rerun the cell, or download the weights manually.**

In [None]:
# Download backend requested for every checkpoint below.
download_tool = "gdown"

# Vanilla GauGAN, without and with SIGE
vanilla_model = get_model(vanilla_args, tool=download_tool)
sige_vanilla_model = get_model(sige_vanilla_args, tool=download_tool)

# GAN Compression, without and with SIGE
gc_model = get_model(gc_args, tool=download_tool)
sige_gc_model = get_model(sige_gc_args, tool=download_tool)

print("Build models successfully!")


### Prepare Data
We have prepared two pairs of user edits in [`./assets`](./assets). Here, we treat the ground-truth semantic label ([`assets/gt_label.npy`](assets/gt_label.npy)) and instance map ([`assets/gt_instance.npy`](assets/gt_instance.npy)) as the original map, and the synthetic semantic label ([`assets/synthetic_label.npy`](assets/synthetic_label.npy)) and instance map ([`assets/synthetic_instance.npy`](assets/synthetic_instance.npy)) as the edited map. You are free to use any other pair of data, either from our benchmark dataset (see [README.md](./README.md)) or created by yourself.

In [None]:
import numpy as np

def _load_map(path: str) -> torch.Tensor:
    """Load a ``.npy`` map onto ``device`` and prepend a channel axis ([H, W] -> [C, H, W])."""
    return torch.from_numpy(np.load(path)).to(device).unsqueeze(0)


# Original maps (ground truth) and edited maps (synthetic).
original_label = _load_map("assets/gt_label.npy")
original_instance = _load_map("assets/gt_instance.npy")
edited_label = _load_map("assets/synthetic_label.npy")
edited_instance = _load_map("assets/synthetic_instance.npy")


Display the maps.

In [None]:
import ipyplot

from utils import tensor2label

# Second argument handed to `tensor2label` for both maps.
label_viz_arg = vanilla_args.input_nc + 1
original_label_viz, edited_label_viz = (
    tensor2label(label, label_viz_arg) for label in (original_label, edited_label)
)

# Show the original and the edited label map side by side.
label_arrays = (np.array(original_label_viz), np.array(edited_label_viz))
label_titles = ("Original", "Edited")
ipyplot.plot_images(label_arrays, label_titles, img_width=vanilla_args.crop_size)


Process the data.

In [None]:
def get_edges(t: torch.Tensor) -> torch.Tensor:
    """Mark every position whose value differs from a horizontal or vertical neighbour.

    Args:
        t: A 4-D tensor; neighbours are compared along the last two axes.

    Returns:
        A float tensor with the same shape and device as ``t``, holding 1.0
        where the value differs from at least one 4-connected neighbour and
        0.0 elsewhere.
    """
    # Allocate next to the input rather than on the notebook-global device, so
    # the function works for tensors on any device.
    edge = torch.zeros(t.size(), dtype=torch.uint8, device=t.device)
    # Each comparison is computed once and applied to both sides of the boundary.
    horizontal = (t[:, :, :, 1:] != t[:, :, :, :-1]).byte()
    vertical = (t[:, :, 1:, :] != t[:, :, :-1, :]).byte()
    edge[:, :, :, 1:] |= horizontal
    edge[:, :, :, :-1] |= horizontal
    edge[:, :, 1:, :] |= vertical
    edge[:, :, :-1, :] |= vertical
    return edge.float()


# Batch the two maps: index 0 holds the original, index 1 the edited one.
label_map = torch.stack([original_label, edited_label]).long()
instance_map = torch.stack([original_instance, edited_instance])

# One-hot encode the labels along the channel axis.
b, c, h, w = label_map.size()
assert c == 1
c = vanilla_args.input_nc
input_label = torch.zeros(b, c, h, w, device=device)
# `scatter_` writes in place, so `input_semantics` starts out as `input_label`.
input_semantics = input_label.scatter_(1, label_map, 1.0)

# Append the instance edge map as an extra channel unless instances are disabled.
if not vanilla_args.no_instance:
    instance_edge_map = get_edges(instance_map)
    input_semantics = torch.cat([input_semantics, instance_edge_map], dim=1)


Compute the difference masks. `sige` has some helper functions for this.

In [None]:
from sige.utils import compute_difference_mask, dilate_mask, downsample_mask

# Mask of where the edited input differs from the original one, then dilated.
original_semantics, edited_semantics = input_semantics[0], input_semantics[1]
difference_mask = dilate_mask(
    compute_difference_mask(original_semantics, edited_semantics, eps=1e-3),
    vanilla_args.mask_dilate_radius,
)

# Downsampled versions of the mask, keyed by resolution.
model_hw = (vanilla_model.sh, vanilla_model.sw)
masks = downsample_mask(difference_mask, model_hw, dilation=vanilla_args.downsample_dilate_radius)


Visualize the masks.

In [None]:
from PIL import Image


def mask_to_image(mask: torch.Tensor):
    """Turn a mask tensor into a PIL image resized to the notebook's display size.

    The target size is ``crop_size`` wide and ``crop_size // aspect_ratio``
    high, both read from ``vanilla_args``.
    """
    width = vanilla_args.crop_size
    height = width // vanilla_args.aspect_ratio
    return Image.fromarray(mask.cpu().numpy()).resize((width, height))


def _sparsity_text(mask: torch.Tensor) -> str:
    """Format the percentage of set elements in ``mask``."""
    return "Sparsity: %.2f%%" % (100 * mask.sum() / mask.numel())


# Full-resolution difference mask.
mask_image = mask_to_image(difference_mask)
message = _sparsity_text(difference_mask)
ipyplot.plot_images((np.array(mask_image),), ("Difference Mask",), (message,), img_width=vanilla_args.crop_size)

# One image, title and caption per downsampled mask.
print("Downsampled Masks")
arrays = [np.array(mask_to_image(mask)) for mask in masks.values()]
labels = ["Resolution: %dx%d" % (res[0], res[1]) for res in masks]
messages = [_sparsity_text(mask) for mask in masks.values()]
ipyplot.plot_images(arrays, labels, messages, img_width=vanilla_args.crop_size)


## Test Models
### Quality Results
Inference.

In [None]:
with torch.no_grad():
    # The dense models run directly on the edited input (batch index 1).
    vanilla_result = vanilla_model(input_semantics[1:])
    gc_result = gc_model(input_semantics[1:])

    # Each SIGE model needs a pre-run in "full" mode on the original input
    # (batch index 0) to determine the data shapes and cache the original
    # results. After the masks are set, it runs on the edited input in
    # "sparse" mode.
    sige_results = []
    for sige_model in (sige_vanilla_model, sige_gc_model):
        sige_model.set_mode("full")
        sige_model(input_semantics[:1])
        sige_model.set_masks(masks)
        sige_model.set_mode("sparse")
        sige_results.append(sige_model(input_semantics[1:]))
    sige_vanilla_result, sige_gc_result = sige_results


Visualize the generated images.

In [None]:
from utils import tensor2im

# Convert the first sample of each result batch into a displayable image.
vanilla_image, sige_vanilla_image, gc_image, sige_gc_image = (
    tensor2im(result[0]) for result in (vanilla_result, sige_vanilla_result, gc_result, sige_gc_result)
)


In [None]:
import ipyplot
import numpy as np

# Show the four generated images side by side, each under its model's name.
generated_images = (vanilla_image, sige_vanilla_image, gc_image, sige_gc_image)
image_titles = ("Vanilla", "SIGE", "GAN Compression", "GAN Comp.+SIGE")
ipyplot.plot_images(generated_images, image_titles, img_width=vanilla_args.crop_size)


### Efficiency Results
First, let's profile the MACs of these models.

In [None]:
from torchprofile import profile_macs

# Create some dummy inputs (the edited sample, kept as a batch of one).
dummy_inputs = (input_semantics[1:],)

with torch.no_grad():
    vanilla_macs = profile_macs(vanilla_model, dummy_inputs)
    # Profile the GAN Compression model itself. This line previously profiled
    # `vanilla_model` again, so the reported GAN Compression MACs were wrong.
    gc_macs = profile_macs(gc_model, dummy_inputs)

    # For the SIGE models, we first need to run them in the `full` mode to cache the results.
    sige_vanilla_model.set_mode("full")
    sige_vanilla_model(*dummy_inputs)
    sige_gc_model.set_mode("full")
    sige_gc_model(*dummy_inputs)

    # We also need to set the difference masks if they are not set yet.
    sige_vanilla_model.set_masks(masks)
    sige_gc_model.set_masks(masks)

    # Switch to the `profile` mode to profile MACs. This mode is only for MACs profiling.
    sige_vanilla_model.set_mode("profile")
    sige_vanilla_macs = profile_macs(sige_vanilla_model, dummy_inputs)
    sige_gc_model.set_mode("profile")
    sige_gc_macs = profile_macs(sige_gc_model, dummy_inputs)


In [None]:
# Report each model's MACs in units of 1e9 (G).
for macs_template, macs in (
    ("Vanilla MACs: %.3fG", vanilla_macs),
    ("SIGE MACs: %.3fG", sige_vanilla_macs),
    ("GAN Compression MACs: %.3fG", gc_macs),
    ("GAN Comp.+SIGE MACs: %.3fG", sige_gc_macs),
):
    print(macs_template % (macs / 1e9))


The SIGE model achieves a $\sim 18\times$ MACs reduction. Combined with GAN Compression, it reduces the computation of the vanilla GauGAN by $\sim 50\times$. Now let's measure the latency.

In [None]:
import time

from tqdm import tqdm

# Number of warm-up runs and of timed runs per model.
# Lower these values if the benchmark takes too long on your machine.
warmup_times, test_times = 100, 100


def measure_latency(model: nn.Module):
    """Time repeated forward passes of ``model`` on ``dummy_inputs``.

    The model is first run ``warmup_times`` times without timing, then
    ``test_times`` times under the timer. The device is synchronized after
    every forward pass.

    Args:
        model: The module to benchmark; it is called as ``model(*dummy_inputs)``.

    Returns:
        A ``(total_seconds, seconds_per_run)`` tuple for the timed runs.
    """
    for _ in tqdm(range(warmup_times)):
        model(*dummy_inputs)
        device_synchronize(device)
    # `perf_counter` is monotonic and high-resolution, unlike the wall-clock
    # `time.time()`, which can jump if the system clock is adjusted.
    start_time = time.perf_counter()
    for _ in tqdm(range(test_times)):
        model(*dummy_inputs)
        device_synchronize(device)
    cost_time = time.perf_counter() - start_time
    return cost_time, cost_time / test_times


with torch.no_grad():
    # The dense models can be timed as they are.
    vanilla_cost, vanilla_avg = measure_latency(vanilla_model)
    gc_cost, gc_avg = measure_latency(gc_model)

    # The SIGE models already hold cached dummy results, so there is no need
    # to rerun them in the `full` mode; just switch each one to the `sparse`
    # mode before timing it.
    sige_timings = []
    for sige_model in (sige_vanilla_model, sige_gc_model):
        sige_model.set_mode("sparse")
        sige_timings.append(measure_latency(sige_model))
    (sige_vanilla_cost, sige_vanilla_avg), (sige_gc_cost, sige_gc_avg) = sige_timings


In [None]:
# Report the total time in seconds and the per-run average in milliseconds.
for latency_template, total_cost, avg_cost in (
    ("Vanilla: Cost %.2fs Avg %.2fms", vanilla_cost, vanilla_avg),
    ("SIGE: Cost %.2fs Avg %.2fms", sige_vanilla_cost, sige_vanilla_avg),
    ("GAN Compression: Cost %.2fs Avg %.2fms", gc_cost, gc_avg),
    ("GAN Comp.+SIGE: Cost %.2fs Avg %.2fms", sige_gc_cost, sige_gc_avg),
):
    print(latency_template % (total_cost, avg_cost * 1000))
