# SIGE Tiling-based Sparse Convolution Usage Example
In this notebook, we will show how to implement a minimal tiling-based sparse convolution with SIGE.

## Setup
1. Install [PyTorch](https://pytorch.org).

In [None]:
!pip install torch

2. Install [SIGE](https://github.com/lmxyy/sige-dev/) and other dependencies. **(This may take several minutes.)**

In [None]:
!pip install sige 
!pip install torchprofile

## Get Started

In [None]:
import argparse
import os

import numpy as np
import torch
from IPython.display import display
from PIL import Image
from torchprofile import profile_macs

from sige.nn import Gather, Scatter, SIGEConv2d, SIGEModel, SIGEModule


### Get Inputs
#### Set the test device and generate the original input.

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Device:", device)

original_input = torch.randn((1, 16, 256, 256), device=device)


#### Get the difference mask.

In [None]:
from torch.hub import download_url_to_file

if not os.path.exists("assets/mask.npy"):
    os.makedirs("assets", exist_ok=True)
    download_url_to_file("https://github.com/lmxyy/sige/blob/main/assets/mask.npy?raw=true", "assets/mask.npy")
mask = np.load("assets/mask.npy")

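# Invert the mask for display so that the masked pixels appear black.
mask_image = Image.fromarray(~mask)
mask = torch.from_numpy(mask).to(device)
display(mask_image)
# Percentage of pixels marked as changed by the difference mask.
print("Difference Mask Sparsity: %.2f%%" % (mask.sum() / mask.numel() * 100).item())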


#### Generate the edited input according to the difference mask.

In [None]:
edited_input = original_input + torch.randn((1, 16, 256, 256), device=device) * mask[None, None]


### Get the Model
We first define a module consisting of a `Gather`, a $3 \times 3$ convolution, and a `Scatter`.

In [None]:
class ExampleModule(SIGEModule):
    def __init__(self):
        super(ExampleModule, self).__init__()
        self.conv = SIGEConv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1, bias=True)
        self.gather = Gather(self.conv, block_size=6)
        self.scatter = Scatter(self.gather)

    def forward(self, x):
        x = self.gather(x)
        x = self.conv(x)
        x = self.scatter(x)
        return x


`SIGEModule` is an `nn.Module` wrapper that supports inference in three different modes:
* `full`: The original inference. For the example above, the full mode just performs the standard $3 \times 3$ convolution.
* `sparse`: The tiling-based sparse convolution. 
* `profile`: This mode is only used when profiling the MACs of the tiling-based convolution.
It also supports setting the difference mask.

`Gather`, `Scatter` and `SIGEConv2d` are also `SIGEModule`s. Specifically,
* `Gather` initialization requires the paired convolution and the sparse block size. During `full` inference, it just records the input shape. During `sparse` inference, it gathers the active blocks according to the `active_indices` reduced from the difference mask. During `profile` inference, it just creates a dummy tensor to symbolically track the computation graph for MACs profiling.
* `Scatter` initialization requires the paired `Gather` module. During `full` inference, it just caches the input tensor. During `sparse` inference, it scatters the input blocks to the cached tensor according to the `active_indices` in the paired `Gather`. During `profile` inference, it just creates a dummy tensor to symbolically track the computation graph for MACs profiling.
* `SIGEConv2d` is just a wrapper of `nn.Conv2d`. During `full` inference, it behaves as the standard convolution. During `sparse` or `profile` inference, the `padding` is 0 because the gathered blocks are already padded.
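
To make the gather, convolve, scatter flow concrete, here is a minimal single-channel sketch in pure Python. It is not SIGE's implementation: `gather`, `conv3x3_valid`, `scatter`, `is_active` and `full_conv` are hypothetical helper names, the data is a toy $8 \times 8$ integer grid with a block size of 4, and the rule for choosing active blocks (any masked pixel inside the block or its 1-pixel halo) is an assumption chosen so that the sparse result matches the full convolution exactly.

```python
# Illustrative sketch only; none of these helpers are part of the SIGE API.

def gather(image, top, left, block_size):
    """Copy one block plus a 1-pixel halo; pixels outside the image become 0."""
    n = len(image)
    size = block_size + 2
    patch = [[0] * size for _ in range(size)]
    for py in range(size):
        for px in range(size):
            yy, xx = top + py - 1, left + px - 1
            if 0 <= yy < n and 0 <= xx < n:
                patch[py][px] = image[yy][xx]
    return patch


def conv3x3_valid(patch, kernel):
    """3x3 convolution with padding 0: the gathered patch is already padded."""
    size = len(patch) - 2
    return [[sum(patch[y + dy][x + dx] * kernel[dy][dx]
                 for dy in range(3) for dx in range(3))
             for x in range(size)] for y in range(size)]


def scatter(output, block, top, left):
    """Overwrite one block of the cached output in place."""
    for y, row in enumerate(block):
        for x, value in enumerate(row):
            output[top + y][left + x] = value


def is_active(mask, top, left, block_size):
    """Assumed rule: a block is active if its halo-extended window is masked."""
    n = len(mask)
    rows = range(max(top - 1, 0), min(top + block_size + 1, n))
    cols = range(max(left - 1, 0), min(left + block_size + 1, n))
    return any(mask[y][x] for y in rows for x in cols)


def full_conv(image, kernel):
    """Zero-padded 3x3 convolution over the whole (square) image."""
    return conv3x3_valid(gather(image, 0, 0, len(image)), kernel)


n, block_size = 8, 4
kernel = [[1, 2, 1], [0, 1, 0], [-1, 0, 2]]
original = [[(3 * y + 5 * x) % 7 for x in range(n)] for y in range(n)]

# Edit two pixels; the second one sits in the halo of the block below.
edited = [row[:] for row in original]
edited[1][2] += 10
edited[3][2] -= 4
mask = [[edited[y][x] != original[y][x] for x in range(n)] for y in range(n)]

active_blocks = [(top, left)
                 for top in range(0, n, block_size)
                 for left in range(0, n, block_size)
                 if is_active(mask, top, left, block_size)]

# "full" pass on the original input caches the output.
sparse_output = [row[:] for row in full_conv(original, kernel)]
# "sparse" pass recomputes only the active blocks of the edited input.
for top, left in active_blocks:
    patch = gather(edited, top, left, block_size)
    scatter(sparse_output, conv3x3_valid(patch, kernel), top, left)

full_output = full_conv(edited, kernel)
print(active_blocks)                 # [(0, 0), (4, 0)]
print(sparse_output == full_output)  # True
```

Only two of the four blocks are recomputed in this toy case, and the result still equals the full convolution on the edited input because the untouched blocks reuse the cached output.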

Then we wrap the `ExampleModule` into a `SIGEModel`:

In [None]:
class ExampleModel(SIGEModel):
    def __init__(self):
        super(ExampleModel, self).__init__()
        self.example_module = ExampleModule()

    def forward(self, x: torch.Tensor):
        return self.example_module(x)


`SIGEModel` is a class that wraps the top-level `nn.Module`. It supports setting the difference masks and the inference mode for all of its child `SIGEModule`s.

Then we can get the model.

In [None]:
model = ExampleModel().to(device)
model.eval()


### Test the Model
First, let's get the results of the full model.

In [None]:
with torch.no_grad():
    model.set_mode("full")
    std_output = model(edited_input)  # for further comparisons
    full_macs = profile_macs(model, (edited_input,))


Let's try the sparse inference with SIGE. We first need to cache the results for the original input:

In [None]:
with torch.no_grad():
    model.set_mode("full")
    original_output = model(original_input)


Then we can try the sparse inference:

In [None]:
with torch.no_grad():
    model.set_mode("sparse")
    model.set_masks({(256, 256): mask})
    sige_output = model(edited_input)
    model.set_mode("profile")
    sige_macs = profile_macs(model, (edited_input,))


`set_masks` takes a `Dict` object as input. Each key is a resolution tuple and each value is the 2D mask tensor for that resolution. Note that `SIGEModel` broadcasts the masks to all its child `SIGEModule`s, including `Gather`. Each `Gather` then reduces the mask of the corresponding resolution to `active_indices`.
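
The resolution-keyed lookup can be sketched in plain Python as follows. This is only an illustration under stated assumptions: `downsample_mask` and `select_mask` are hypothetical helpers, not SIGE functions, and the second (coarser) resolution is invented here to show why the dictionary is keyed by resolution; the notebook itself only uses a single $256 \times 256$ mask.

```python
# Hypothetical helpers for illustration; not part of the SIGE API.

def downsample_mask(mask, factor):
    """Mark a coarse pixel as changed if any fine pixel it covers is changed."""
    h, w = len(mask) // factor, len(mask[0]) // factor
    return [[any(mask[y * factor + dy][x * factor + dx]
                 for dy in range(factor) for dx in range(factor))
             for x in range(w)] for y in range(h)]


def select_mask(masks, feature_shape):
    """Pick the mask whose (height, width) key matches an NCHW feature shape."""
    return masks[tuple(feature_shape[-2:])]


fine = [[False] * 8 for _ in range(8)]
fine[5][6] = True  # a single edited pixel

# One mask per resolution, keyed by (height, width).
masks = {(8, 8): fine, (4, 4): downsample_mask(fine, 2)}

# A module that sees a 4x4 feature map looks up the 4x4 mask.
coarse = select_mask(masks, (1, 32, 4, 4))
changed = [(y, x) for y, row in enumerate(coarse) for x, v in enumerate(row) if v]
print(changed)  # [(2, 3)]
```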

Now let's compare the results of the standard convolution and the SIGE sparse convolution.

In [None]:
print("Max Error: %.6f" % abs(std_output - sige_output).max().item())
print("Masked Region: %.2f%%" % (mask.sum() / mask.numel() * 100).item())
print("Full MACs: %.2fM" % (full_macs / 1e6))
print("SIGE MACs: %.2fM" % (sige_macs / 1e6))


SIGE reduces the computation by $5.23\times$ in this example. Please refer to our [diffusion model](https://github.com/lmxyy/sige/tree/main/diffusion) and [GauGAN](https://github.com/lmxyy/sige-dev/tree/main/gaugan) benchmarks for more usage examples.
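
As a rough sanity check on the MACs numbers, the counts can be estimated by hand. The sketch below assumes one MAC per multiply-add, ignores the bias, and assumes each active block costs one `block_size` $\times$ `block_size` output tile. The value of `num_active_blocks` is a hypothetical placeholder, since the real count depends on the difference mask, so the printed ratio is not expected to reproduce the figure above.

```python
# Back-of-the-envelope MAC counts for the 3x3 convolution in this notebook.
in_channels, out_channels, kernel_size = 16, 32, 3
height = width = 256
block_size = 6

# Each output pixel needs in_channels * kernel_size^2 MACs per output channel.
macs_per_output_pixel = in_channels * out_channels * kernel_size ** 2
full_macs = height * width * macs_per_output_pixel
print("Full MACs: %.2fM" % (full_macs / 1e6))  # Full MACs: 301.99M


def sparse_macs(num_active_blocks):
    """Assumed cost model: every active block yields block_size^2 output pixels."""
    return num_active_blocks * block_size ** 2 * macs_per_output_pixel


num_active_blocks = 400  # hypothetical; the real value depends on the mask
ratio = full_macs / sparse_macs(num_active_blocks)
print("Reduction: %.2fx" % ratio)
```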