<br />
<div align="center">
  <a href="https://deepwok.github.io/">
    <img src="../imgs/deepwok.png" alt="Logo" width="160" height="160">
  </a>

  <h1 align="center">Lab 4 for Advanced Deep Learning Systems (ADLS) - Software Stream</h1>

  <p align="center">
    ELEC70109/EE9-AML3-10/EE9-AO25
    <br />
		Written by
    <a href="https://aaron-zhao123.github.io/">Aaron Zhao, Pedro Gimenes </a>
  </p>
</div>

# General introduction

In this lab, you will learn how to optimize the performance of PyTorch-based models.

We will cover the following topics:
1. Automatic performance tuning using existing high-level flows (e.g. `torch.compile`), including its main building blocks and how it works.
2. Manual performance enhancement with techniques such as kernel fusion. 
3. Manual performance enhancement with low-level custom kernels.

# torch.compile

`torch.compile` makes PyTorch code run faster by compiling it just in time (JIT). It captures the operations that a model or function performs into a graph and generates optimized code for that graph on the specific hardware. It ships with PyTorch from version 2.0 onwards.

JIT compilation is a technique that translates code into machine code at runtime, rather than ahead of time, in order to improve performance. It is used by the runtimes of many programming languages, including Python, Java, and C#. A JIT compiler typically analyses the code as it executes and identifies the parts where the speedup gained from compilation or recompilation would outweigh the overhead of compiling that code.


We will use `torch.compile` to optimize the performance of some arbitrary functions and of a model. Its usage is simple: either add `torch.compile` as a decorator on the function you want to optimize, or call it on an existing function or model to obtain a compiled wrapper.
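The two usage styles can be sketched as follows. This is a minimal illustration rather than part of the lab code: the function names `tanh_gate` and `scaled_sum` are hypothetical, and `backend="eager"` is an assumption made so that the sketch runs without a C++ compiler or Triton (it only exercises graph capture and performs no code generation). Dropping that argument selects the default TorchInductor backend.

```python
import torch

# Style 1: decorator. Calling torch.compile with keyword arguments only
# returns a decorator that is applied to the function below.
@torch.compile(backend="eager")
def tanh_gate(x):
    return 0.5 * x * (1.0 + torch.tanh(x))

# Style 2: wrap an existing function (or nn.Module) after it is defined.
def scaled_sum(x, y):
    return torch.sin(x) + 2.0 * torch.cos(y)

compiled_scaled_sum = torch.compile(scaled_sum, backend="eager")
compiled_linear = torch.compile(torch.nn.Linear(8, 4), backend="eager")

x = torch.randn(8, 8)
y = torch.randn(8, 8)

out_decorated = tanh_gate(x)             # compilation happens on the first call
out_wrapped = compiled_scaled_sum(x, y)
out_linear = compiled_linear(x)

# The compiled wrapper should agree with the original eager function.
print(torch.allclose(out_wrapped, scaled_sum(x, y)))
```

In both styles, compilation is triggered lazily by the first call with real inputs, not by the call to `torch.compile` itself.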

We first instantiate the helper functions below:


In [17]:
import torch
import time
from chop.models import get_model
from chop.dataset import get_dataset_info

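def timed_gpu(fn):
    """Run fn() and return (result, elapsed seconds), timed with CUDA events."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    # Wait for all queued GPU work to finish before reading the timer.
    torch.cuda.synchronize()
    # elapsed_time() returns milliseconds, so convert to seconds.
    return result, start.elapsed_time(end) / 1000

def timed_cpu(fn):
    """Run fn() and return (result, elapsed seconds), timed with the wall clock."""
    start = time.time()
    result = fn()
    return result, time.time() - start

def get_data():
    """Return a random batch of 128 RGB images of size 224x224."""
    return torch.randn(128, 3, 224, 224)

def time_model(fn, n=1000, device='cpu'):
    """Call fn on the same batch n times and return the average time in seconds."""
    times = []
    data = get_data().to(device)
    for _ in range(n):
        if device == 'cpu':
            _, t = timed_cpu(lambda: fn(data))
        else:
            _, t = timed_gpu(lambda: fn(data))
        times.append(t)
    avg_time = sum(times) / len(times)
    return avg_time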

Now, define the neural network. We use ResNet-18, a model that performs image classification on the ImageNet dataset.

In [18]:
task = "classification"
model = get_model("resnet18", pretrained=True, num_classes=1000)
image = torch.randn(64, 3, 224, 224)

opt_model = torch.compile(model)

Now, we can inspect the runtime of the model. 

In [20]:
device = "cpu"
n = 5

model.to(device)
opt_model.to(device)
avg_t = time_model(model, n=n, device=device)
opt_avg_t = time_model(opt_model, n=n, device=device)
print(f"Original model: {avg_t:.4f} s")
print(f"Optimized model: {opt_avg_t:.4f} s")

Original model: 2.6594 s
Optimized model: 3.1050 s


Although results vary from machine to machine, you will very likely see that the optimized model is actually slower than the original. There are a few reasons for this, which are left as an exercise for the reader to investigate.

The core idea of `torch.compile` is that it is an automatic optimization tool. It has three main building blocks:

- TorchDynamo: `torch._dynamo` captures the PyTorch graph by hooking into CPython's frame evaluation and analysing the Python bytecode as it runs (https://github.com/pytorch/pytorch/tree/main/torch/_dynamo).
- TorchInductor: `torch._inductor` (https://github.com/pytorch/pytorch/tree/main/torch/_inductor) can be seen as the backend of `torch.compile`. It is responsible for providing the kernel implementations of the operators in the PyTorch graph, and mainly uses OpenAI's Triton to implement these kernels (https://github.com/pytorch/pytorch/tree/main/torch/_inductor/kernel).
- AOT Autograd: This allows you to capture the whole graph, including the backward pass, ahead of time.

TorchInductor also uses the `torch.fx` graph representation to pattern-match code, as shown [here](https://github.com/pytorch/pytorch/tree/main/torch/_inductor/fx_passes).
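To make the capture step more concrete, the sketch below passes a custom backend to `torch.compile` in order to look at the `torch.fx` graph that TorchDynamo produces. The names `inspect_backend`, `captured_graphs` and `relu_plus_one` are hypothetical and chosen for illustration. The backend performs no optimization at all: it records the graph and returns it unchanged, so this shows only the hand-over point where a backend such as TorchInductor would normally take over.

```python
import torch
import torch.fx

captured_graphs = []

def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    """Custom backend: record and print the captured graph, then run it as-is."""
    captured_graphs.append(gm)
    print(gm.graph)      # textual listing of the captured FX nodes
    return gm.forward    # a backend must return a callable

def relu_plus_one(x):
    return torch.relu(x) + 1.0

compiled_fn = torch.compile(relu_plus_one, backend=inspect_backend)

x = torch.randn(4)
out = compiled_fn(x)     # the first call triggers capture and calls the backend

# Collect the operation types recorded in the graph (placeholder, call_function, output, ...).
node_ops = [node.op for node in captured_graphs[0].graph.nodes]
print(node_ops)
```

A real backend would transform `gm` or generate kernels for it before returning a callable; returning `gm.forward` simply executes the captured operations with ordinary PyTorch.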

# Using a fused kernel



Fused kernels are advantageous because they reduce both the number of memory accesses and the number of kernel launches: intermediate results no longer need to be written to memory by one kernel and read back by the next. This is particularly useful when you have many small operations that can be fused together.

The following example demonstrates how to use a fused kernel to optimize the performance of the scaled dot product attention (SDPA). The two different implementations are shown below:

In [31]:
import math
import torch
import torch.nn.functional as F

class ScaledDotProductAttention(torch.nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, query, key, value):
        # Scale the scores by 1/sqrt(d_k), the default scaling of F.scaled_dot_product_attention.
        scale_factor = 1 / math.sqrt(query.size(-1))
        score = query @ key.transpose(-2, -1) * scale_factor
        attn = F.softmax(score, -1)
        context = attn @ value
        return context

class ScaledDotProductAttentionFused(torch.nn.Module):
    def forward(self, query, key, value):
        return F.scaled_dot_product_attention(query, key, value)

The first implementation composes several separate operations (matrix multiplication, scaling and softmax) to compute the scaled dot product attention. The second implementation uses a fused kernel provided by PyTorch to compute the same result.

Now we can do a very simple test to check their functional equivalence.

In [32]:
device = "cpu"
query = torch.ones(32, 8, 128, 64, dtype=torch.float16, device=device)
key = torch.ones(32, 8, 128, 64, dtype=torch.float16, device=device)
value = torch.ones(32, 8, 128, 64, dtype=torch.float16, device=device)

y1 = ScaledDotProductAttention()(query, key, value)
y2 = ScaledDotProductAttentionFused()(query, key, value)
print(y1[0,0,0,0], y2[0,0,0,0])

tensor(1., dtype=torch.float16) tensor(1., dtype=torch.float16)
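Because every input in the test above is 1, all attention scores are equal and the softmax is uniform whatever scaling is applied, so this check cannot detect a scaling mistake. A stricter check can be sketched as follows, assuming small random float32 inputs on the CPU. The helper `manual_sdpa`, the tensor shapes and the tolerance are illustrative choices, not part of the lab code.

```python
import math
import torch
import torch.nn.functional as F

def manual_sdpa(query, key, value):
    """Unfused reference: softmax(Q K^T / sqrt(d_k)) V."""
    scale_factor = 1 / math.sqrt(query.size(-1))
    score = query @ key.transpose(-2, -1) * scale_factor
    return F.softmax(score, dim=-1) @ value

torch.manual_seed(0)
# Random tensors of shape (batch, heads, sequence length, head dimension).
query, key, value = (torch.randn(2, 4, 16, 8) for _ in range(3))

y_manual = manual_sdpa(query, key, value)
y_fused = F.scaled_dot_product_attention(query, key, value)
matches = torch.allclose(y_manual, y_fused, atol=1e-4)
print("fused matches manual:", matches)

# The same computation with the scaling applied the wrong way round
# (multiplying by sqrt(d_k) instead of dividing by it).
wrong_score = query @ key.transpose(-2, -1) * math.sqrt(query.size(-1))
y_wrong = F.softmax(wrong_score, dim=-1) @ value
wrong_matches = torch.allclose(y_wrong, y_fused, atol=1e-4)
print("wrong scaling matches fused:", wrong_matches)
```

Comparing whole tensors with `torch.allclose` on random data, rather than a single element on constant data, makes the test sensitive to errors in the scaling, the softmax axis and the matrix products.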
