Compare each pipeline-parallel (PP) schedule: simple case, T_F=1, T_B=2, M = inf, i

In [20]:
# set PYTHONPATH
import sys
sys.path.append('/users/ctianche/Megatron-LM/')

from typing import Type
from IPython.display import display
from megatron.core.pipeline_parallel.cdc_scheduler.pp_generator.pipeline import GpipePipeline, Hanayo1F1BPipeline, HeuristicLoopZBPipeline, HeuristicWaveZBPipeline, HeuristicWaveZBPipelineV2, HeuristicZBVPipeline, Interleaved1F1BPipeline, OneFOneBPipeline, Pipeline, ZBH1Pipeline, HeuristicZBUDPipeline
from megatron.core.pipeline_parallel.cdc_scheduler.pp_generator.pipeline_config import SystemConfig
from megatron.core.pipeline_parallel.cdc_scheduler.pp_generator.util import generate_comm_mat


def test_pipeline(
    PipelineClass: Type[Pipeline],
    sys_config: SystemConfig,
) -> None:
    """Build one pipeline schedule, print its summary and display it.

    Args:
        PipelineClass: Pipeline class to instantiate with ``sys_config``.
        sys_config: Configuration handed to the pipeline constructor; its
            ``num_devices`` also decides how much detail is displayed.

    The pipeline is constructed, scheduled and has its dependencies solved,
    then its device-wise schedule time and bubble ratio are printed. Finally
    the result of ``print_schedule`` is passed to IPython's ``display``.
    """
    schedule = PipelineClass(sys_config)
    schedule.schedule()
    schedule.solve_dependencies()

    runtime = schedule.get_schedule_time(device_wise=True)
    bubble = schedule.get_bubble_ratio(device_wise=True)
    print(f"{PipelineClass.__name__} runtime: {runtime}, bubble: {bubble:.2f}")

    # Per-op info is only requested for fewer than 10 devices.
    include_info = sys_config.num_devices < 10
    display(schedule.print_schedule(include_info=include_info))

def _try_build_pipeline(PipelineClass, sys_config, label=None):
    """Construct, schedule and dependency-solve a pipeline; return None on failure.

    The comparison in ``test_basic_schedule`` is best-effort: a heuristic that
    cannot produce a schedule for a given configuration must not abort the
    remaining ones. The failure is printed rather than silently dropped so a
    missing candidate can be explained.

    Args:
        PipelineClass: Pipeline class to instantiate with ``sys_config``.
        sys_config: Configuration object passed to the constructor.
        label: Optional name used in the failure message; defaults to the
            class name.

    Returns:
        The scheduled pipeline, or ``None`` if any of the three steps raised.
    """
    try:
        pipeline = PipelineClass(sys_config)
        pipeline.schedule()
        pipeline.solve_dependencies()
    except Exception as exc:  # deliberate best-effort: report and carry on
        name = PipelineClass.__name__ if label is None else label
        print(f"{name} failed: {exc!r}")
        return None
    return pipeline


def test_basic_schedule(num_dev=32, num_microbatches=64, T_F=200, T_B=200, T_W=200, comm_matrix=None):
    """Compare the ZBV heuristic against the best wave-style heuristic schedule.

    A single two-chunk ``SystemConfig`` is built from the arguments. The
    ``HeuristicZBVPipeline`` result is printed first. Then
    ``HeuristicWaveZBPipeline`` and all eight ``aux_*`` flag combinations of
    ``HeuristicWaveZBPipelineV2`` are scheduled, and the candidate with the
    smallest device-wise schedule time is printed (the first one wins ties).

    Args:
        num_dev: Number of devices passed to ``SystemConfig``.
        num_microbatches: Number of microbatches passed to ``SystemConfig``.
        T_F: Cost passed as ``T_F`` after dividing by the number of chunks
            (presumably the forward time; confirm against ``SystemConfig``).
        T_B: Cost passed as ``T_B`` after dividing by the number of chunks.
        T_W: Cost passed as ``T_W`` after dividing by the number of chunks.
        comm_matrix: Value passed as ``T_C``. When ``None`` it defaults to
            ``generate_comm_mat(1, num_dev, 0, 0)``.

    Raises:
        RuntimeError: If none of the wave-style candidates could be scheduled.
    """
    from copy import deepcopy
    from itertools import product

    num_chunks = 2
    T_F_chunk = T_F / num_chunks
    T_B_chunk = T_B / num_chunks
    T_W_chunk = T_W / num_chunks
    M_F = 2
    M_B = -0.3
    M_W = -1.7
    M_limit = num_dev * num_chunks * M_F

    if comm_matrix is None:
        comm_matrix = generate_comm_mat(1, num_dev, 0, 0)

    heur_zbv_sys_config = SystemConfig(
        num_devices=num_dev,
        num_microbatches=num_microbatches,
        T_F=T_F_chunk,
        T_B=T_B_chunk,
        T_W=T_W_chunk,
        T_C=comm_matrix,
        num_chunks=num_chunks,
        M_F=M_F,
        M_B=M_B,
        M_W=M_W,
        M_Limit=M_limit,
    )

    # Reference schedule: reported on its own, not part of the candidate pool.
    zbv = _try_build_pipeline(HeuristicZBVPipeline, heur_zbv_sys_config)
    if zbv is not None:
        print(f"ZBV runtime: {zbv.get_schedule_time(device_wise=True)}, bubble: {zbv.get_bubble_ratio(device_wise=True):.2f}")
        zbv.print_schedule(include_info=True)

    candidates = []
    wave_v1 = _try_build_pipeline(HeuristicWaveZBPipeline, heur_zbv_sys_config)
    if wave_v1 is not None:
        candidates.append(wave_v1)

    # product(..., repeat=3) iterates with the first flag varying slowest,
    # i.e. in the same order as three nested [True, False] loops, which keeps
    # the tie-breaking among equally fast candidates unchanged.
    for aux_1b1w, aux_tear_down_opt, aux_w_if_b_mem_limited in product([True, False], repeat=3):
        # Each variant gets its own copy so flags do not leak between runs.
        cfg = deepcopy(heur_zbv_sys_config)
        cfg.aux_1b1w = aux_1b1w
        cfg.aux_tear_down_opt = aux_tear_down_opt
        cfg.aux_w_if_b_mem_limited = aux_w_if_b_mem_limited
        label = (
            f"{HeuristicWaveZBPipelineV2.__name__}("
            f"aux_1b1w={aux_1b1w}, "
            f"aux_tear_down_opt={aux_tear_down_opt}, "
            f"aux_w_if_b_mem_limited={aux_w_if_b_mem_limited})"
        )
        pipe = _try_build_pipeline(HeuristicWaveZBPipelineV2, cfg, label=label)
        if pipe is not None:
            candidates.append(pipe)

    if not candidates:
        raise RuntimeError(
            "No wave heuristic pipeline could be scheduled for this configuration"
        )

    # Select the fastest candidate; min() returns the first minimal entry.
    timed = [(pipe.get_schedule_time(device_wise=True), pipe) for pipe in candidates]
    best_time, best_pipeline = min(timed, key=lambda entry: entry[0])
    print(f"Best pipeline: {best_pipeline.__class__.__name__}, runtime: {best_time}, bubble: {best_pipeline.get_bubble_ratio(device_wise=True):.2f}")
    best_pipeline.print_schedule(include_info=True)
                
                

In [21]:
# 4 devices, 8 microbatches, T_F = T_B = T_W = 200;
# communication matrix from generate_comm_mat(2, 2, 0, 0).
test_basic_schedule(
    num_dev=4,
    num_microbatches=8,
    T_F=200,
    T_B=200,
    T_W=200,
    comm_matrix=generate_comm_mat(2, 2, 0, 0),
)

ZBV runtime: 4800.0, bubble: 0.00
Best pipeline: HeuristicWaveZBPipelineV2, runtime: 4900.0, bubble: 0.02


In [22]:
# 4 devices, 8 microbatches, T_F = T_B = T_W = 200;
# communication matrix from generate_comm_mat(2, 2, 0, 100).
test_basic_schedule(
    num_dev=4,
    num_microbatches=8,
    T_F=200,
    T_B=200,
    T_W=200,
    comm_matrix=generate_comm_mat(2, 2, 0, 100),
)

ZBV runtime: 5400.0, bubble: 0.11
Best pipeline: HeuristicWaveZBPipelineV2, runtime: 5500.0, bubble: 0.13


In [23]:
# 4 devices, 8 microbatches, unequal costs (T_F=200, T_B=180, T_W=100);
# communication matrix from generate_comm_mat(2, 2, 0, 100).
test_basic_schedule(
    num_dev=4,
    num_microbatches=8,
    T_F=200,
    T_B=180,
    T_W=100,
    comm_matrix=generate_comm_mat(2, 2, 0, 100),
)

ZBV runtime: 4670.0, bubble: 0.18
Best pipeline: HeuristicWaveZBPipelineV2, runtime: 4870.0, bubble: 0.21


In [24]:
# 32 devices, 64 microbatches, T_F=200, T_B=200, T_W=140;
# communication matrix from generate_comm_mat(2, 16, 0, 400).
test_basic_schedule(
    num_dev=32,
    num_microbatches=64,
    T_F=200,
    T_B=200,
    T_W=140,
    comm_matrix=generate_comm_mat(2, 16, 0, 400),
)

ZBV runtime: 60210.0, bubble: 0.43
Best pipeline: HeuristicWaveZBPipelineV2, runtime: 42450.0, bubble: 0.19


In [25]:
# 32 devices, 64 microbatches, T_F=200, T_B=250, T_W=150;
# communication matrix from generate_comm_mat(4, 8, 0, 100).
test_basic_schedule(
    num_dev=32,
    num_microbatches=64,
    T_F=200,
    T_B=250,
    T_W=150,
    comm_matrix=generate_comm_mat(4, 8, 0, 100),
)

ZBV runtime: 45150.0, bubble: 0.15
Best pipeline: HeuristicWaveZBPipelineV2, runtime: 47600.0, bubble: 0.19


In [26]:
# 4 devices, 8 microbatches, T_F=180, T_B=220, T_W=200;
# communication matrix from generate_comm_mat(1, 4, 0, 0).
test_basic_schedule(
    num_dev=4,
    num_microbatches=8,
    T_F=180,
    T_B=220,
    T_W=200,
    comm_matrix=generate_comm_mat(1, 4, 0, 0),
)

ZBV runtime: 4860.0, bubble: 0.01
Best pipeline: HeuristicWaveZBPipelineV2, runtime: 5130.0, bubble: 0.06


In [27]:
# 32 devices, 64 microbatches, T_F = T_B = T_W = 200;
# communication matrix from generate_comm_mat(2, 16, 0, 80).
test_basic_schedule(
    num_dev=32,
    num_microbatches=64,
    T_F=200,
    T_B=200,
    T_W=200,
    comm_matrix=generate_comm_mat(2, 16, 0, 80),
)

ZBV runtime: 40420.0, bubble: 0.05
Best pipeline: HeuristicWaveZBPipelineV2, runtime: 41660.0, bubble: 0.08


In [28]:
# 32 devices, 64 microbatches, T_F=220, T_B=240, T_W=200;
# communication matrix from generate_comm_mat(2, 16, 0, 400).
test_basic_schedule(
    num_dev=32,
    num_microbatches=64,
    T_F=220,
    T_B=240,
    T_W=200,
    comm_matrix=generate_comm_mat(2, 16, 0, 400),
)

ZBV runtime: 65860.0, bubble: 0.36
Best pipeline: HeuristicWaveZBPipelineV2, runtime: 51050.0, bubble: 0.17
