# PP simulation
Currently, the covered pipeline schedules are: GPipe, 1F1B, interleaved 1F1B, Hanayo (1-wave and 2-wave), and ZB-H1. 

## Notation
$D$: number of devices

$d_i$: device $i$. Specifically, $d_{-1}$ means the last device.

$MBS$: number of microbatches between two flushes.

$T_F(d_i)$: forward time of a microbatch on device $d_i$. If constant among devices, it can be abbreviated as $T_F$.

$T_B(d_i)$: backward time of a microbatch on device $d_i$. For Zero-Bubble schedules, it only counts activation backward time. If constant among devices, it can be abbreviated as $T_B$. 

$T_W(d_i)$: (only in Zero-Bubble) weight gradient backward time of a microbatch on device $d_i$. If constant among devices, it can be abbreviated as $T_W$.  

$T_C$: P2P send/recv communication time between two devices. We assume $T_C$ is constant among all possible device pairs.

$T_{tot}$: total execution time of the pipeline, measured from the start of the first microbatch on any device to the end of the last microbatch on any device. (This definition does not favour the async optimizer stepping idea in the Zero-Bubble paper, which focuses on minimizing the longest time spent on any single device from the start of its first microbatch to the end of its last microbatch.) 

$r_{bub}$: bubble ratio of the pipeline, defined as $1-\frac{\sum_{d_i} MBS * [T_F(d_i) + T_B(d_i) + T_W(d_i)]}{D * T_{tot}}$. If $T_F, T_B, T_W$ are constant across devices, this simplifies to $1-\frac{MBS * [T_F + T_B + T_W]}{T_{tot}}$.



## Disclaimer
We assume $D$ divides $MBS$ WLOG. Hanayo specifically is restricted to $MBS=D$ (at least, the paper does not show the more general case). For a fair comparison, we iterate the Hanayo scheme $MBS/D$ times. 

We do not list $T_B$ and $T_W$ of Zero-Bubble separately. We provide a single $T_B$ for all schedules and, for Zero-Bubble, assume the activation backward time equals the weight backward time. 

Interleaved/Wave-like schedules further cut each microbatch into smaller chunks, so the $T_F, T_B, T_W$ of each chunk are divided by the number of chunks (we assume an equal split).

Default settings if not specified: $MBS=2D$, async send/recv.

In [None]:
from pipeline import SystemConfig, Pipeline, GpipePipeline, OneFOneBPipeline, Interleaved1F1BPipeline, Hanayo1F1BPipeline, ZBH1Pipeline
import matplotlib.pyplot as plt
import numpy as np
from typing import List

### Weak scaling $D$ 
Configuration: $D = 2,4,8,16,32, MBS = 2*D, T_F = 400/D, T_B = 800 / D, T_C=0,4,16$


In [None]:
# Schedules compared throughout this notebook, as (plot label, pipeline class) pairs.
# Both Hanayo entries share one class: the experiments below give the
# 'Hanayo wave=1' entry a config with num_chunks=2 and the 'Hanayo wave=2'
# entry a config with num_chunks=4.
_SCHEDULES = [
    ('Gpipe', GpipePipeline),
    ('1F1B', OneFOneBPipeline),
    ('1F1B vpp=2', Interleaved1F1BPipeline),
    ('Hanayo wave=1', Hanayo1F1BPipeline),
    ('Hanayo wave=2', Hanayo1F1BPipeline),
    ('ZBH1', ZBH1Pipeline),
]
# Parallel lists; the experiment functions zip them with a per-schedule config list.
names = [label for label, _ in _SCHEDULES]
cls_list = [pipeline_cls for _, pipeline_cls in _SCHEDULES]

def weak_scale_D(T_C=0, print_schedule=False):
    """Simulate and plot weak scaling over the device count.

    D sweeps 2..32 with MBS = 2*D, T_F = 400/D and T_B = 800/D. Every
    schedule in ``names`` is simulated, then total time and bubble ratio
    are plotted against D in two side-by-side panels.

    Args:
        T_C: P2P communication time handed to every ``SystemConfig``.
        print_schedule: when true, call ``print_schedule(name)`` on each
            pipeline after it has been simulated.
    """
    device_counts = [2, 4, 8, 16, 32]
    mbs_per_device = 2  # MBS/D
    # One pair of result series per schedule, ordered like device_counts.
    sim_res = {label: {"total_time": [], "bubble": []} for label in names}

    for d in device_counts:
        t_b = 800 / d
        # Keyword arguments shared by every schedule's config at this device count.
        shared = dict(T_F=400 / d, T_C=T_C, num_devices=d, num_microbatches=mbs_per_device * d)
        # Listed in the same order as `names` / `cls_list`.
        configs = [
            SystemConfig(T_B=t_b, **shared),
            SystemConfig(T_B=t_b, **shared),
            SystemConfig(T_B=t_b, num_chunks=2, **shared),
            SystemConfig(T_B=t_b, num_chunks=2, **shared),
            SystemConfig(T_B=t_b, num_chunks=4, **shared),
            # ZB-H1: the backward time is split evenly between T_B and T_W.
            SystemConfig(T_B=t_b / 2, T_W=t_b / 2, **shared),
        ]
        for label, config, pipeline_cls in zip(names, configs, cls_list):
            pipeline: Pipeline = pipeline_cls(config)
            tot_time, bubble = pipeline.compute_schedule_time_and_bubble()
            if print_schedule:
                pipeline.print_schedule(label)
            series = sim_res[label]
            series["total_time"].append(tot_time)
            series["bubble"].append(bubble)

    fig, axes = plt.subplots(1, 2, layout='tight', figsize=(10, 5))
    panels = zip(axes, ("total_time", "bubble"), ('Total Time', 'Bubble Ratio'))
    for ax, metric, y_label in panels:
        for label in names:
            ax.plot(device_counts, sim_res[label][metric], label=label)
        ax.set_xlabel('Number of devices')
        ax.set_ylabel(y_label)
        ax.set_xticks(device_counts)
        ax.legend()
    fig.suptitle(f'Weak scaling number of devices, MBS=2D, T_C={T_C}')
    plt.show()
    
# Draw the weak-scaling figure once per simulated communication time.
for _t_c in (0, 4, 16):
    weak_scale_D(_t_c)

### Strong scaling $D$ 
Configuration: $D = 2,4,8,16,32, MBS = 32, T_F = 400/D, T_B = 800 / D, T_C=0,4,16$


In [None]:
def strong_scale_D(T_C=0, print_schedule=False):
    """Simulate and plot strong scaling over the device count.

    The microbatch count is fixed at 32 while D sweeps 2..32 with
    T_F = 400/D and T_B = 800/D. Every schedule in ``names`` is simulated,
    then total time (log scale) and bubble ratio are plotted against D.

    Args:
        T_C: P2P communication time handed to every ``SystemConfig``.
        print_schedule: when true, call ``print_schedule(name)`` on each
            pipeline after it has been simulated.
    """
    D = [2, 4, 8, 16, 32]
    mbs = 32
    sim_res = {name: {"total_time": [], "bubble": []} for name in names}

    for d in D:
        T_F = 400 / d
        T_B = 800 / d
        # Listed in the same order as `names` / `cls_list`.
        configs = [
            SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs),
            SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs),
            SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs, num_chunks=2),
            SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs, num_chunks=2),
            SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs, num_chunks=4),
            # ZB-H1: the backward time is split evenly between T_B and T_W.
            SystemConfig(T_F=T_F, T_B=T_B / 2, T_W=T_B / 2, T_C=T_C, num_devices=d, num_microbatches=mbs),
        ]
        for name, config, cls in zip(names, configs, cls_list):
            pipeline: Pipeline = cls(config)
            try:
                tot_time, bubble = pipeline.compute_schedule_time_and_bubble()
                if print_schedule:
                    pipeline.print_schedule(name)
            except Exception:
                # Best effort: a schedule that fails for this (D, MBS) is
                # recorded as None and shows up as a gap in its curve.
                # Catching Exception (not a bare except) lets
                # KeyboardInterrupt / SystemExit still stop the run.
                tot_time = None
                bubble = None

            sim_res[name]["total_time"].append(tot_time)
            sim_res[name]["bubble"].append(bubble)

    fig, (ax1, ax2) = plt.subplots(1, 2, layout='tight', figsize=(10, 5))
    for name in names:
        ax1.plot(D, sim_res[name]["total_time"], label=name)
        ax2.plot(D, sim_res[name]["bubble"], label=name)
    for ax, y_label in ((ax1, 'Total Time'), (ax2, 'Bubble Ratio')):
        ax.set_xlabel('Number of devices')
        ax.set_ylabel(y_label)
        ax.set_xticks(D)
        ax.legend()
    ax1.set_yscale('log')

    fig.suptitle(f'Strong scaling number of devices, MBS=32, T_C={T_C}')
    plt.show()
    
# Draw the strong-scaling figure once per simulated communication time.
for _t_c in (0, 4, 16):
    strong_scale_D(_t_c)

### Recompute
We define the recompute ratio $r_{recomp}\in [0,1]$, where $0$ means no recompute and $1$ means full recompute. The corresponding backward time becomes $T_B=(2+r_{recomp})T_F$.

Configuration: $D = 2,4,8,16,32, MBS = 2D, T_F = 400/D, T_B = (2+r_{recomp})T_F, r_{recomp} = 0,0.2,0.4,0.6,0.8,1, T_C=0,4,16$


In [None]:
from collections import defaultdict


def recompute_bubble(d=2, print_schedule=False):
    """Plot bubble ratio against the recompute ratio for ``d`` devices.

    For each T_C in {0, 4, 16} and each recompute ratio r in 0, 0.2, ..., 1,
    every schedule in ``names`` is simulated with MBS = 2*d, T_F = 400/d and
    T_B = (2 + r) * T_F. ZB-H1 gets T_B = (1 + r) * T_F and T_W = T_F, so
    the recompute cost is added to its T_B only. One panel is drawn per T_C.

    Args:
        d: number of devices.
        print_schedule: when true, call ``print_schedule(name)`` on each
            pipeline after it has been simulated.
    """
    T_C_list = [0, 4, 16]
    mbs = 2 * d  # MBS/D = 2
    # linspace fixes the number of points and includes the endpoint 1.0;
    # a float-step arange leaves the length to rounding.
    r_recomp_list = np.linspace(0, 1, 6)
    # sim_res[name][T_C] -> bubble ratios ordered like r_recomp_list
    sim_res = {name: defaultdict(list) for name in names}

    T_F = 400 / d
    zb_T_W = T_F
    for T_C in T_C_list:
        for r_recomp in r_recomp_list:
            T_B = T_F * (2 + r_recomp)
            zb_T_B = T_F * (1 + r_recomp)
            # Listed in the same order as `names` / `cls_list`.
            configs = [
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs),
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs),
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs, num_chunks=2),
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs, num_chunks=2),
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs, num_chunks=4),
                SystemConfig(T_F=T_F, T_B=zb_T_B, T_W=zb_T_W, T_C=T_C, num_devices=d, num_microbatches=mbs),
            ]
            for name, config, cls in zip(names, configs, cls_list):
                pipeline: Pipeline = cls(config)
                try:
                    _, bubble = pipeline.compute_schedule_time_and_bubble()
                    if print_schedule:
                        pipeline.print_schedule(name)
                except Exception:
                    # Best effort: a failing schedule is recorded as None and
                    # shows up as a gap. Catching Exception (not a bare
                    # except) lets KeyboardInterrupt / SystemExit still stop
                    # the run.
                    bubble = None

                sim_res[name][T_C].append(bubble)

    fig_cols = len(T_C_list)
    fig, axes = plt.subplots(1, fig_cols, layout='tight', figsize=(5 * fig_cols, 5))
    for ax, T_C in zip(axes, T_C_list):
        for name in names:
            ax.plot(r_recomp_list, sim_res[name][T_C], label=name)
        # Decorate each panel once, after all of its curves are drawn.
        ax.set_xlabel('Recomp ratio')
        ax.set_ylabel('Bubble Ratio')
        ax.set_title(f'T_C={T_C}')
        ax.set_xticks(r_recomp_list)
        ax.legend()

    fig.suptitle(f'Recompute vs. Bubble, MBS=2D, D={d}')
    plt.show()
    
# Draw the recompute figure for each device count.
for _d in (2, 4, 8, 16, 32):
    recompute_bubble(_d)

### MBS
Configuration: $D = 2,4,8,16, MBS/D = 1,2,4,8, T_F = 400/D, T_B = 800 / D, T_C=0,4,16$

In [None]:
def mbs_bubble(d=2, print_schedule=False):
    """Plot bubble ratio against MBS/D for ``d`` devices.

    For each T_C in {0, 4, 16} and each MBS/D in {1, 2, 4, 8}, every
    schedule in ``names`` is simulated with T_F = 400/d and T_B = 800/d.
    One panel is drawn per T_C.

    Args:
        d: number of devices.
        print_schedule: when true, call ``print_schedule(name)`` on each
            pipeline after it has been simulated.
    """
    T_C_list = [0, 4, 16]
    repeat_list = [1, 2, 4, 8]  # MBS/D
    # sim_res[name][T_C] -> bubble ratios ordered like repeat_list
    sim_res = {name: defaultdict(list) for name in names}

    T_F = 400 / d
    T_B = 800 / d
    for T_C in T_C_list:
        for repeat in repeat_list:
            mbs = repeat * d
            # Listed in the same order as `names` / `cls_list`.
            configs = [
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs),
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs),
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs, num_chunks=2),
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs, num_chunks=2),
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs, num_chunks=4),
                # ZB-H1: the backward time is split evenly between T_B and T_W.
                SystemConfig(T_F=T_F, T_B=T_B / 2, T_W=T_B / 2, T_C=T_C, num_devices=d, num_microbatches=mbs),
            ]
            for name, config, cls in zip(names, configs, cls_list):
                pipeline: Pipeline = cls(config)
                try:
                    _, bubble = pipeline.compute_schedule_time_and_bubble()
                    if print_schedule:
                        pipeline.print_schedule(name)
                except Exception:
                    # Best effort: a failing schedule is recorded as None and
                    # shows up as a gap. Catching Exception (not a bare
                    # except) lets KeyboardInterrupt / SystemExit still stop
                    # the run.
                    bubble = None

                sim_res[name][T_C].append(bubble)

    fig_cols = len(T_C_list)
    fig, axes = plt.subplots(1, fig_cols, layout='tight', figsize=(5 * fig_cols, 5))
    for ax, T_C in zip(axes, T_C_list):
        for name in names:
            ax.plot(repeat_list, sim_res[name][T_C], label=name)
        # Decorate each panel once, after all of its curves are drawn.
        ax.set_xlabel('MBS/Devices')
        ax.set_ylabel('Bubble Ratio')
        ax.set_title(f'T_C={T_C}')
        ax.set_xticks(repeat_list)
        ax.legend()

    fig.suptitle(f'MBS vs. Bubble, D={d}')
    plt.show()
    
# Draw the MBS figure for each device count.
for _d in (2, 4, 8, 16):
    mbs_bubble(_d)

### Waves, VPP
Configuration: $D = 2,4,8,16,32, MBS/D = 2, T_F = 400/D, T_B = 800 / D, T_C=0,4,16, chunks = 2,4,6,8$

For interleaved 1F1B, $vpp=chunks$. For Hanayo, $wave = chunks/2$

In [None]:
def chunk_bubble(d=2, print_schedule=False):
    """Plot bubble ratio against the chunk count for ``d`` devices.

    Only the two chunked schedules are compared: interleaved 1F1B and
    Hanayo. For each T_C in {0, 4, 16} and each chunk count in {2, 4, 6, 8},
    both are simulated with MBS = 2*d, T_F = 400/d and T_B = 800/d. One
    panel is drawn per T_C.

    Args:
        d: number of devices.
        print_schedule: when true, call ``print_schedule(name)`` on each
            pipeline after it has been simulated.
    """
    T_C_list = [0, 4, 16]
    chunk_list = [2, 4, 6, 8]
    mbs = 2 * d  # MBS/D = 2
    T_F = 400 / d
    T_B = 800 / d

    # Local table rather than the file-level `names` / `cls_list`, which
    # also contain the non-chunked schedules.
    schedules = [
        ('Interleaved 1F1B', Interleaved1F1BPipeline),
        ('Hanayo', Hanayo1F1BPipeline),
    ]
    # sim_res[name][T_C] -> bubble ratios ordered like chunk_list
    sim_res = {name: defaultdict(list) for name, _ in schedules}

    for T_C in T_C_list:
        for chunk in chunk_list:
            for name, cls in schedules:
                # Each schedule gets its own config instance.
                config = SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d,
                                      num_microbatches=mbs, num_chunks=chunk)
                pipeline: Pipeline = cls(config)
                try:
                    _, bubble = pipeline.compute_schedule_time_and_bubble()
                    if print_schedule:
                        pipeline.print_schedule(name)
                except Exception:
                    # Best effort: a failing schedule is recorded as None and
                    # shows up as a gap. Catching Exception (not a bare
                    # except) lets KeyboardInterrupt / SystemExit still stop
                    # the run.
                    bubble = None

                sim_res[name][T_C].append(bubble)

    fig_cols = len(T_C_list)
    fig, axes = plt.subplots(1, fig_cols, layout='tight', figsize=(5 * fig_cols, 5))
    for ax, T_C in zip(axes, T_C_list):
        for name, _ in schedules:
            ax.plot(chunk_list, sim_res[name][T_C], label=name)
        # Decorate each panel once, after all of its curves are drawn.
        ax.set_xlabel('chunks')
        ax.set_ylabel('Bubble Ratio')
        ax.set_title(f'T_C={T_C}')
        ax.set_xticks(chunk_list)
        ax.legend()

    fig.suptitle(f'Chunks vs. Bubble, D={d}')
    plt.show()
    
# Draw the chunk-count figure for each device count.
for _d in (2, 4, 8, 16, 32):
    chunk_bubble(_d)

### T_B/T_W in Zero-Bubble
Configuration: $D = 2,4,8,16,32, MBS/D = 1,2,4,8, T_F = 400/D, T_W+T_B = 800 / D, T_C=0,4,16, \frac{T_B}{T_B+T_W}=0.2,0.4,0.5,0.6,0.8$


In [None]:
def T_B_T_C_bubble(mbs_per_d=1, print_schedule=False):
    """Plot the ZB-H1 bubble ratio against the share T_B / (T_B + T_W).

    For each T_C in {0, 4, 16} and each D in 2..32, ZB-H1 is simulated with
    MBS = mbs_per_d * D, T_F = 400/D and T_B + T_W = 800/D, where the T_B
    share takes the values 0.2, 0.4, 0.5, 0.6 and 0.8. One panel is drawn
    per T_C, with one curve per D.

    Args:
        mbs_per_d: microbatches per device (MBS/D).
        print_schedule: when true, call ``print_schedule`` on each pipeline
            after it has been simulated.
    """
    T_C_list = [0, 4, 16]
    D = [2, 4, 8, 16, 32]
    T_B_ratios = [0.2, 0.4, 0.5, 0.6, 0.8]
    # sim_res[T_C][d] -> bubble ratios ordered like T_B_ratios
    sim_res = {T_C: defaultdict(list) for T_C in T_C_list}

    for T_C in T_C_list:
        for d in D:
            T_F = 400 / d
            mbs = mbs_per_d * d
            configs = [
                SystemConfig(
                    T_F=T_F,
                    T_B=800 * T_B_ratio / d,
                    T_W=800 * (1 - T_B_ratio) / d,
                    T_C=T_C,
                    num_devices=d,
                    num_microbatches=mbs,
                )
                for T_B_ratio in T_B_ratios
            ]
            for config in configs:
                pipeline: Pipeline = ZBH1Pipeline(config)
                try:
                    _, bubble = pipeline.compute_schedule_time_and_bubble()
                    if print_schedule:
                        # Pass the schedule label as every other call site
                        # in this file does. If print_schedule has no
                        # default for it, the bare call would raise and the
                        # point would be silently recorded as None.
                        pipeline.print_schedule('ZBH1')
                except Exception:
                    # Best effort: a failing configuration is recorded as
                    # None and shows up as a gap. Catching Exception (not a
                    # bare except) lets KeyboardInterrupt / SystemExit still
                    # stop the run.
                    bubble = None

                sim_res[T_C][d].append(bubble)

    fig_cols = len(T_C_list)
    fig, axes = plt.subplots(1, fig_cols, layout='tight', figsize=(5 * fig_cols, 5))
    for ax, T_C in zip(axes, T_C_list):
        for d in D:
            ax.plot(T_B_ratios, sim_res[T_C][d], label=f'D={d}')
        # Decorate each panel once, after all of its curves are drawn.
        ax.set_xlabel('$T_B/(T_B+T_W)$')
        ax.set_ylabel('Bubble Ratio')
        ax.set_title(f'T_C={T_C}')
        ax.set_xticks(T_B_ratios)
        ax.legend()

    fig.suptitle(f'$T_B/(T_B+T_W)$ vs. Bubble, MBS={mbs_per_d}D')
    plt.show()
    
# Draw the T_B/T_W figure for each MBS/D setting.
for _mbs_per_d in (1, 2, 4, 8):
    T_B_T_C_bubble(_mbs_per_d)

### Bubble sensitivity: $T_C$
Configuration: $D = 2,4,8,16,32, MBS/D = 1,2,4, T_F = 400/D, T_B = 800 / D, T_C=0,2,4,8,16,32$


In [None]:
def sensitivity_bubble_T_C(d=2, print_schedule=False):
    """Plot bubble ratio against the communication time T_C for ``d`` devices.

    For each MBS/D in {1, 2, 4} and each T_C in {0, 2, 4, 8, 16, 32}, every
    schedule in ``names`` is simulated with T_F = 400/d and T_B = 800/d.
    One panel is drawn per MBS/D value.

    Args:
        d: number of devices.
        print_schedule: when true, call ``print_schedule(name)`` on each
            pipeline after it has been simulated.
    """
    T_C_list = [0, 2, 4, 8, 16, 32]
    repeat_list = [1, 2, 4]  # MBS/D
    # sim_res[name][repeat] -> bubble ratios ordered like T_C_list
    sim_res = {name: defaultdict(list) for name in names}

    T_F = 400 / d
    T_B = 800 / d
    for repeat in repeat_list:
        mbs = repeat * d
        for T_C in T_C_list:
            # Listed in the same order as `names` / `cls_list`.
            configs = [
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs),
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs),
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs, num_chunks=2),
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs, num_chunks=2),
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs, num_chunks=4),
                # ZB-H1: the backward time is split evenly between T_B and T_W.
                SystemConfig(T_F=T_F, T_B=T_B / 2, T_W=T_B / 2, T_C=T_C, num_devices=d, num_microbatches=mbs),
            ]
            for name, config, cls in zip(names, configs, cls_list):
                pipeline: Pipeline = cls(config)
                try:
                    _, bubble = pipeline.compute_schedule_time_and_bubble()
                    if print_schedule:
                        pipeline.print_schedule(name)
                except Exception:
                    # Best effort: a failing schedule is recorded as None and
                    # shows up as a gap. Catching Exception (not a bare
                    # except) lets KeyboardInterrupt / SystemExit still stop
                    # the run.
                    bubble = None

                sim_res[name][repeat].append(bubble)

    fig_cols = len(repeat_list)
    fig, axes = plt.subplots(1, fig_cols, layout='tight', figsize=(5 * fig_cols, 5))
    for ax, repeat in zip(axes, repeat_list):
        for name in names:
            ax.plot(T_C_list, sim_res[name][repeat], label=name)
        # Decorate each panel once, after all of its curves are drawn.
        ax.set_xlabel('T_C')
        ax.set_ylabel('Bubble Ratio')
        ax.set_title(f'repeat={repeat}')
        ax.set_xticks(T_C_list)
        ax.legend()

    fig.suptitle(f'Bubble Sensitivity vs. T_C, D={d}')
    plt.show()
    
# Draw the T_C sensitivity figure for each device count.
for _d in (2, 4, 8, 16, 32):
    sensitivity_bubble_T_C(_d)

### Imbalanced workloads

We only set imbalanced workloads on the first and last stages, and denote the extra forward time as $T_{ext}$. So, $T_F(d_0)+=T_{ext}, T_F(d_{-1})+=T_{ext}, T_B(d_0)+=2T_{ext}, T_B(d_{-1})+=2T_{ext}$.

Configuration: $D = 2,4,8,16,32, MBS/D = 2, T_F = 400/D, T_B = 800 / D, T_{ext}/T_F = 0,0.1,0.2,0.3,0.4,0.5,T_C=0,4,16$


In [None]:
def imbalanced_bubble(d=2, print_schedule=False):
    """Plot bubble ratio against first/last-stage imbalance for ``d`` devices.

    The first and last stages get an extra forward time
    T_ext = ratio * T_F and an extra backward time 2 * T_ext, for ratio in
    0, 0.1, ..., 0.5. For ZB-H1 the whole backward extra goes into T_B and
    T_W is left unchanged. Base per-stage times are T_F = 400/d and
    T_B = 800/d, matching the configuration documented for this experiment
    and the other experiments in this file. MBS = 2*d. One panel is drawn
    per T_C in {0, 4, 16}.

    Args:
        d: number of devices.
        print_schedule: when true, call ``print_schedule(name)`` on each
            pipeline after it has been simulated.
    """
    T_C_list = [0, 4, 16]
    mbs = 2 * d  # MBS/D = 2
    # linspace includes the documented 0.5 endpoint; arange(0, 0.5, 0.1)
    # stopped at 0.4.
    T_ext_ratios = np.linspace(0, 0.5, 6)
    # sim_res[name][T_C] -> bubble ratios ordered like T_ext_ratios
    sim_res = {name: defaultdict(list) for name in names}

    # Balanced per-stage times. They are divided by d so that T_C has the
    # same weight relative to compute here as in the other experiments.
    base_T_F = 400 / d
    base_T_B = 800 / d
    base_zb = 400 / d  # ZB-H1 activation-backward and weight-backward time
    edge_stages = {0, d - 1}

    for T_C in T_C_list:
        for T_ext_ratio in T_ext_ratios:
            # Per-stage time lists; only the first and last stages are slowed.
            T_F = [base_T_F * (1 + T_ext_ratio) if i in edge_stages else base_T_F for i in range(d)]
            T_B = [base_T_B * (1 + T_ext_ratio) if i in edge_stages else base_T_B for i in range(d)]
            zb_T_B = [base_zb * (1 + 2 * T_ext_ratio) if i in edge_stages else base_zb for i in range(d)]
            zb_T_W = [base_zb for _ in range(d)]
            # Listed in the same order as `names` / `cls_list`.
            configs = [
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs),
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs),
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs, num_chunks=2),
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs, num_chunks=2),
                SystemConfig(T_F=T_F, T_B=T_B, T_C=T_C, num_devices=d, num_microbatches=mbs, num_chunks=4),
                SystemConfig(T_F=T_F, T_B=zb_T_B, T_W=zb_T_W, T_C=T_C, num_devices=d, num_microbatches=mbs),
            ]
            for name, config, cls in zip(names, configs, cls_list):
                pipeline: Pipeline = cls(config)
                try:
                    _, bubble = pipeline.compute_schedule_time_and_bubble()
                    if print_schedule:
                        pipeline.print_schedule(name)
                except Exception:
                    # Best effort: a failing schedule is recorded as None and
                    # shows up as a gap. Catching Exception (not a bare
                    # except) lets KeyboardInterrupt / SystemExit still stop
                    # the run.
                    bubble = None

                sim_res[name][T_C].append(bubble)

    fig_cols = len(T_C_list)
    fig, axes = plt.subplots(1, fig_cols, layout='tight', figsize=(5 * fig_cols, 5))
    for ax, T_C in zip(axes, T_C_list):
        for name in names:
            ax.plot(T_ext_ratios, sim_res[name][T_C], label=name)
        # Decorate each panel once, after all of its curves are drawn.
        ax.set_xlabel(r'$T_{ext}/T_F$')
        ax.set_ylabel('Bubble Ratio')
        ax.set_title(f'T_C={T_C}')
        ax.set_xticks(T_ext_ratios)
        ax.legend()

    fig.suptitle(f'Imbalanced workloads vs. Bubble, D={d}')
    plt.show()
    
# Draw the imbalance figure for each device count; the schedules are
# printed only for the 4-device run.
for _d in (2, 4, 8, 16, 32):
    imbalanced_bubble(_d, print_schedule=(_d == 4))
    