# Decouple Model Execution from Definition

Hongzheng Chen, Cody Hao Yu, Shuai Zheng


## Background and Motivation
* **Performance**: There is a large gap between native implementations and optimized models (numbers from Mu's slide):
    | Model\Performance (TFLOPs) | HuggingFace | Megatron-LM |
    | :--: | :--: | :--: |
    | BERT | 31 | **43** |
    | GPT2 | 19 | **42** |
* **Productivity**
    * Megatron-LM, DeepSpeed ZeRO-3: Parameter sharding; MiCS: prefetching, buffer pre-allocation
    * Each of these optimizations requires manually modifying the model
    * Components are not reusable for models other than Transformers
* **Customizability**
    * Alpa automatically searches for the optimal 3D parallelism
    * Compilation passes are monolithic: one cannot apply an optimization to specific layers only and inspect the result (e.g. shard a single op)
    * Optimizations are opaque: it is hard to locate issues inside the compiler


## Proposal: A Model Scheduling DSL
Decouple model execution from definition
* TVM/Halide: Only consider op-level optimizations, and only for single-machine inference workloads

Optimizations covered (each requires manually changing the model in existing work):
1. Parameter sharding
2. Kernel fusion/injection
3. Gradient checkpointing
4. Memory defragmentation
5. ...
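
To illustrate why these optimizations usually require editing the model, here is a minimal sketch of gradient checkpointing applied by hand in plain PyTorch. The `Block` class and its `use_checkpoint` flag are hypothetical names for this example only; `torch.utils.checkpoint.checkpoint` is the standard PyTorch utility. Note how an execution decision ends up inside the model definition.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Hypothetical feed-forward block used only for illustration."""

    def __init__(self, dim: int, use_checkpoint: bool = False):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim * 2)
        self.fc2 = nn.Linear(dim * 2, dim)
        # An execution detail that leaks into the model definition.
        self.use_checkpoint = use_checkpoint

    def _inner(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

    def forward(self, x):
        if self.use_checkpoint:
            # Do not keep intermediate activations; recompute them in backward.
            return checkpoint(self._inner, x, use_reentrant=False)
        return self._inner(x)


torch.manual_seed(0)
plain = Block(8)
ckpt = Block(8, use_checkpoint=True)
ckpt.load_state_dict(plain.state_dict())  # same weights in both variants

x = torch.rand(4, 8, requires_grad=True)
out_plain = plain(x)
out_ckpt = ckpt(x)
same_output = torch.allclose(out_plain, out_ckpt)
print("Same output:", same_output)
```

A scheduling DSL aims to express the same choice outside the model, so the definition stays untouched.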

## Demos

### (1) Kernel Injection

Import required packages

In [None]:
import os, sys, copy, time
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import ms # model-scheduling
print(torch.cuda.is_available())

Create a simple MLP model with layer norm

In [None]:
N = 2048

class MLP(nn.Module):

    def __init__(self, dim: int):
        super().__init__()
        self.dense_1 = nn.Linear(dim, dim * 2)
        self.layer_norm = nn.LayerNorm([dim, dim * 2])
        self.activation = nn.ReLU()
        self.dense_2 = nn.Linear(dim * 2, dim)

    def forward(self, x):
        x = self.dense_1(x)
        x = self.layer_norm(x)
        x = self.activation(x)
        x = self.dense_2(x)
        return x

Instantiate the model and move it to the GPU

In [None]:
device = "cuda:0"
model = MLP(N).to(device)

Create a default schedule

In [None]:
sch = ms.create_schedule(copy.deepcopy(model))

Currently we use torch.fx to trace the model and generate IR for optimization. We can print out the graph module to see the operators.

In [None]:
print(sch.gm.graph)
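
For readers unfamiliar with `torch.fx`, the following sketch shows what tracing produces, independent of `ms`. `TinyMLP` is a hypothetical stand-in for the model above (with a conventional `LayerNorm(dim * 2)`), and the snippet runs on CPU.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class TinyMLP(nn.Module):
    """Hypothetical small model mirroring the MLP structure above."""

    def __init__(self, dim: int):
        super().__init__()
        self.dense_1 = nn.Linear(dim, dim * 2)
        self.layer_norm = nn.LayerNorm(dim * 2)
        self.activation = nn.ReLU()
        self.dense_2 = nn.Linear(dim * 2, dim)

    def forward(self, x):
        return self.dense_2(self.activation(self.layer_norm(self.dense_1(x))))


gm = symbolic_trace(TinyMLP(4))

# Each node records its kind (placeholder, call_module, output, ...) and its target.
node_summary = [(node.op, str(node.target)) for node in gm.graph.nodes]
for op, target in node_summary:
    print(f"{op:<12} {target}")

# The call_module nodes appear in execution order.
module_targets = [t for op, t in node_summary if op == "call_module"]
```

The `call_module` targets are the attribute names of the submodules, which is what makes per-operator scheduling possible.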

Print operators in the module

In [None]:
ops = sch.forward_ops
print(ops)

Replace layer_norm with Apex layer_norm. Just a single line!

In [None]:
from apex.normalization.fused_layer_norm import FusedLayerNorm

sch[ops[1]].replace(FusedLayerNorm, [N, N * 2])
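
One way such a replacement could work on a `torch.fx` graph is sketched below. This is an assumption about the mechanism, not the actual `ms` implementation. `MyLayerNorm` is a hypothetical stand-in for an optimized kernel such as Apex's `FusedLayerNorm`, so the snippet runs without Apex or a GPU.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class MyLayerNorm(nn.LayerNorm):
    """Hypothetical stand-in for an optimized layer norm implementation."""


model = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8), nn.ReLU())
gm = symbolic_trace(model)

for node in gm.graph.nodes:
    if node.op != "call_module":
        continue
    old = gm.get_submodule(node.target)
    if isinstance(old, nn.LayerNorm):
        new = MyLayerNorm(old.normalized_shape, eps=old.eps)
        new.load_state_dict(old.state_dict())  # keep the learned parameters
        # Node targets are attribute paths, so swapping the attribute is enough;
        # the graph itself does not change.
        parent_name, _, attr = node.target.rpartition(".")
        setattr(gm.get_submodule(parent_name), attr, new)

x = torch.rand(2, 4)
outputs_match = torch.allclose(model(x), gm(x))
replaced_type = type(gm.get_submodule("1")).__name__
print(replaced_type, outputs_match)
```

The original `model` is left untouched because the traced module keeps its own table of submodules.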

Apply the schedule and regenerate the module

In [None]:
opt_model, optimizer = ms.build(sch)
print(opt_model.graph)

In [None]:
inp = torch.rand(N, N).to(device)
original_output = model(inp)
optimized_output = opt_model(inp)
np.testing.assert_almost_equal(original_output.cpu().detach().numpy(), optimized_output.cpu().detach().numpy(), decimal=5)
print("Results are correct!")

In [None]:
pt_time = []
apex_time = []
for i in range(100):
    inp = torch.rand(N, N).to(device)
    # Test native PyTorch implementation.
    # CUDA kernels launch asynchronously, so synchronize before reading the clock.
    torch.cuda.synchronize()
    start_time = time.time()
    original_output = model(inp)
    torch.cuda.synchronize()
    pt_time.append((time.time() - start_time) * 1000)
    # Test Apex function
    start_time = time.time()
    optimized_output = opt_model(inp)
    torch.cuda.synchronize()
    apex_time.append((time.time() - start_time) * 1000)

# plot results
plt.plot(np.arange(100), pt_time, label="pytorch")
plt.plot(np.arange(100), apex_time, label="apex")
plt.legend()
print("PyTorch: {:.4f}ms".format(np.mean(pt_time)))
print("Apex: {:.4f}ms".format(np.mean(apex_time)))

### (1.1) Kernel Fusion

Similarly, we can replace a series of ops with a single fused op/block.

In [None]:
class FusedBlock(nn.Module):

    def __init__(self, dim: int):
        super().__init__()
        self.fc = nn.Linear(dim, dim * 2)
        self.ln = nn.LayerNorm([dim, dim * 2])
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.ln(self.fc(x)))
        return x

sch = ms.create_schedule(copy.deepcopy(model))
ops = sch.forward_ops
sch[ops[0:3]].replace(FusedBlock, N)
print(sch.gm.graph)
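
Replacing a run of operators takes more graph surgery than a single swap. The sketch below shows one possible way to do it directly with `torch.fx`; it is an assumption for illustration, not how `ms` implements `replace`. `TinyMLP` is a hypothetical stand-in model, and this `FusedBlock` simply wraps the existing layers so the outputs can be compared.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class TinyMLP(nn.Module):
    """Hypothetical small model mirroring the MLP structure above."""

    def __init__(self, dim: int):
        super().__init__()
        self.dense_1 = nn.Linear(dim, dim * 2)
        self.layer_norm = nn.LayerNorm(dim * 2)
        self.activation = nn.ReLU()
        self.dense_2 = nn.Linear(dim * 2, dim)

    def forward(self, x):
        return self.dense_2(self.activation(self.layer_norm(self.dense_1(x))))


class FusedBlock(nn.Module):
    """Wraps three existing layers; a real fused block would call one kernel."""

    def __init__(self, fc, ln, relu):
        super().__init__()
        self.fc, self.ln, self.relu = fc, ln, relu

    def forward(self, x):
        return self.relu(self.ln(self.fc(x)))


model = TinyMLP(4)
gm = symbolic_trace(model)
call_nodes = [n for n in gm.graph.nodes if n.op == "call_module"]
to_fuse = call_nodes[0:3]  # dense_1, layer_norm, activation

# Register the fused block, then insert a node that calls it on the original input.
gm.add_submodule("fused", FusedBlock(*(gm.get_submodule(n.target) for n in to_fuse)))
with gm.graph.inserting_after(to_fuse[-1]):
    fused_node = gm.graph.call_module("fused", args=to_fuse[0].args)

# Redirect consumers of the last fused op, then erase the old nodes back to front.
to_fuse[-1].replace_all_uses_with(fused_node)
for node in reversed(to_fuse):
    gm.graph.erase_node(node)
gm.delete_all_unused_submodules()
gm.recompile()

x = torch.rand(2, 4)
outputs_match = torch.allclose(model(x), gm(x))
fused_targets = [str(n.target) for n in gm.graph.nodes if n.op == "call_module"]
print(fused_targets, outputs_match)
```

Erasing in reverse order matters because a node can only be removed once nothing else in the graph uses it.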

### (2) Parameter Sharding



In [None]:
def train(rank, world_size):
    print(f"Running basic MLP example on rank {rank}.")

    # === Model execution schedule ===
    model = MLP(32).cuda(rank)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.002)

    # Create a default schedule
    sch = ms.create_schedule(model, optimizer, world_size, rank)
    
    # Access operators
    ops = sch.forward_ops

    # Partition parameters
    # column sharding for dense_1
    sch[ops[0]].partition(axis=0, param="weight")
    # row sharding for dense_2
    sch[ops[3]].partition(axis=1, param="weight")

    # Partition outputs
    # The result from dense_2 needs aggregation by dim 0
    sch[ops[3]].partition(axis=0)

    # Apply schedule and regenerate module
    model, optimizer = ms.build(sch)

    # Run a few iterations of forward, backward, and
    # optimizer steps on the sharded module.
    for i in range(5):
        start_time = time.time()
        # The batch size must equal dim (32): MLP's LayerNorm normalizes over [dim, dim * 2]
        inp = torch.rand(32, 32).cuda(rank)
        output = model(inp)
        output.sum().backward()
        optimizer.step()
        # Clear gradients so they do not accumulate across iterations
        optimizer.zero_grad()
        elapsed_time = time.time() - start_time
        print(f"Finish step {i}, time: {elapsed_time:.10f}s")
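
The arithmetic behind the column/row partitioning can be checked on a single process. The sketch below is a simulation under stated assumptions, not the `ms` implementation: it uses plain tensors instead of `nn.Linear`, omits the layer norm (which needs the full feature dimension), and replaces the all-reduce with a local sum. All variable names are for illustration only.

```python
import torch

torch.manual_seed(0)
world_size = 2
dt = torch.float64  # double precision keeps the comparison tight

x = torch.rand(16, 32, dtype=dt)
# Weights follow the nn.Linear layout: (out_features, in_features).
w1, b1 = torch.rand(64, 32, dtype=dt), torch.rand(64, dtype=dt)  # dense_1
w2, b2 = torch.rand(32, 64, dtype=dt), torch.rand(32, dtype=dt)  # dense_2

# Reference: unsharded forward pass.
full = torch.relu(x @ w1.T + b1) @ w2.T + b2

# Shard dense_1 along axis 0 (output features) and dense_2 along axis 1 (input features).
w1_shards = w1.chunk(world_size, dim=0)
b1_shards = b1.chunk(world_size, dim=0)
w2_shards = w2.chunk(world_size, dim=1)

# Each simulated rank computes a partial result from its own shards only.
partials = [
    torch.relu(x @ w1_i.T + b1_i) @ w2_i.T
    for w1_i, b1_i, w2_i in zip(w1_shards, b1_shards, w2_shards)
]

# Summing the partials plays the role of an all-reduce; the bias is added once.
combined = torch.stack(partials).sum(dim=0) + b2
matches = torch.allclose(full, combined)
print("Sharded result matches:", matches)
```

Because ReLU is elementwise, each rank can apply it to its own slice of the hidden activations, so only one aggregation is needed after the second layer.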

In [None]:
! python3 test.py