Compare each pipeline-parallel (PP) schedule. Simple case: T_F = 1, T_B = 2, M = inf, i…

In [None]:
from typing import Type
from IPython.display import display
from pipeline import GpipePipeline, Hanayo1F1BPipeline, HeuristicWaveZBPipeline, HeuristicWaveZBPipelineV2, HeuristicZBVPipeline, Interleaved1F1BPipeline, OneFOneBPipeline, Pipeline, ZBH1Pipeline
from pipeline_config import SystemConfig
from util import generate_comm_mat


def test_pipeline(
    PipelineClass: Type[Pipeline],
    sys_config: SystemConfig,
) -> None:
    """Build one pipeline schedule, then print and display its results.

    The pipeline is constructed from ``sys_config``, scheduled, and has its
    dependencies solved. The device-wise schedule time and bubble ratio are
    printed on one line, and the object returned by ``print_schedule`` is
    handed to IPython's ``display``.

    Args:
        PipelineClass: Pipeline class to instantiate with ``sys_config``.
        sys_config: Configuration passed to the pipeline constructor. Its
            ``num_devices`` attribute also decides whether the displayed
            schedule includes extra info (only when fewer than 10 devices).
    """
    schedule = PipelineClass(sys_config)
    schedule.schedule()
    schedule.solve_dependencies()

    runtime = schedule.get_schedule_time(device_wise=True)
    bubble = schedule.get_bubble_ratio(device_wise=True)
    print(f"{PipelineClass.__name__} runtime: {runtime}, bubble: {bubble:.2f}")

    # Extra per-entry info is only requested for configurations under 10 devices.
    show_info = sys_config.num_devices < 10
    display(schedule.print_schedule(include_info=show_info))

def test_basic_schedule(num_dev=32, num_microbatches=64, T_F=200, T_B=200, T_W=200, comm_matrix=None):
    """Run every pipeline schedule variant on one device/workload setting.

    Several ``SystemConfig`` objects are built from the same inputs and each
    pipeline class is passed through ``test_pipeline``:

    * 1F1B and GPipe use ``T_F`` and the combined ``T_B + T_W`` as ``T_B``.
    * Interleaved 1F1B and Hanayo use the same combination, with every time
      divided by ``num_chunks`` (fixed at 2).
    * ZB-H1 uses ``T_F``, ``T_B`` and ``T_W`` as given.
    * The heuristic ZB-V / Wave-ZB pipelines use the per-chunk times together
      with ``M_F``/``M_B``/``M_W`` and ``M_Limit``.
    * The two Wave-ZB pipelines are run again for every entry of
      ``mem_factor``, with ``M_Limit`` multiplied by that factor.

    Args:
        num_dev: Passed to ``SystemConfig`` as ``num_devices``; also scales
            the base ``M_Limit``.
        num_microbatches: Passed to ``SystemConfig`` unchanged.
        T_F, T_B, T_W: Time costs forwarded to ``SystemConfig``
            (presumably forward / backward / weight-gradient — TODO confirm).
        comm_matrix: Value used as ``T_C``. When ``None``, it is created with
            ``generate_comm_mat(1, num_dev, 0, 0)``.
    """
    num_chunks = 2
    chunk_T_F = T_F / num_chunks
    chunk_T_B = T_B / num_chunks
    chunk_T_W = T_W / num_chunks
    M_F, M_B, M_W = 2, -1, -1
    mem_factor = [2]
    M_limit = num_dev * num_chunks * M_F

    if comm_matrix is None:
        comm_matrix = generate_comm_mat(1, num_dev, 0, 0)

    # Keyword arguments shared by every configuration below.
    common = dict(
        num_devices=num_dev,
        num_microbatches=num_microbatches,
        T_C=comm_matrix,
    )

    # gpipe, 1f1b: no separate T_W, so it is folded into T_B.
    sys_config = SystemConfig(T_F=T_F, T_B=T_B + T_W, **common)

    # iv1f1b, hanayo: same folding, on per-chunk times.
    interleaved_sys_config = SystemConfig(
        T_F=chunk_T_F,
        T_B=chunk_T_B + chunk_T_W,
        num_chunks=num_chunks,
        **common,
    )

    zbh1_sys_config = SystemConfig(T_F=T_F, T_B=T_B, T_W=T_W, **common)

    # Heuristic pipelines: per-chunk times plus the M_* settings; only
    # M_Limit differs between the base config and the scaled ones.
    heuristic_kwargs = dict(
        common,
        T_F=chunk_T_F,
        T_B=chunk_T_B,
        T_W=chunk_T_W,
        num_chunks=num_chunks,
        M_F=M_F,
        M_B=M_B,
        M_W=M_W,
    )
    heur_zbv_sys_config = SystemConfig(M_Limit=M_limit, **heuristic_kwargs)
    additional_cfg_heur = [
        SystemConfig(M_Limit=M_limit * scale, **heuristic_kwargs)
        for scale in mem_factor
    ]

    runs = (
        (OneFOneBPipeline, sys_config),
        (GpipePipeline, sys_config),
        (Interleaved1F1BPipeline, interleaved_sys_config),
        (Hanayo1F1BPipeline, interleaved_sys_config),
        (ZBH1Pipeline, zbh1_sys_config),
        (HeuristicZBVPipeline, heur_zbv_sys_config),
        (HeuristicWaveZBPipeline, heur_zbv_sys_config),
        (HeuristicWaveZBPipelineV2, heur_zbv_sys_config),
    )
    for pipeline_cls, config in runs:
        test_pipeline(pipeline_cls, config)

    wave_pipelines = (HeuristicWaveZBPipeline, HeuristicWaveZBPipelineV2)
    for factor, scaled_config in zip(mem_factor, additional_cfg_heur):
        print(f'Heuristic: mem_factor: {factor}')
        for pipeline_cls in wave_pipelines:
            test_pipeline(pipeline_cls, scaled_config)
    
# 4 devices, 8 microbatches, equal T_F/T_B/T_W of 200, with a communication
# matrix built by generate_comm_mat(2, 2, 0, 800).
test_basic_schedule(
    num_dev=4,
    num_microbatches=8,
    T_F=200,
    T_B=200,
    T_W=200,
    comm_matrix=generate_comm_mat(2, 2, 0, 800),
)

In [None]:
# 32 devices, 64 microbatches, equal T_F/T_B/T_W of 200, with a communication
# matrix built by generate_comm_mat(1, 32, 0, 0).
test_basic_schedule(
    num_dev=32,
    num_microbatches=64,
    T_F=200,
    T_B=200,
    T_W=200,
    comm_matrix=generate_comm_mat(1, 32, 0, 0),
)

In [None]:
# 4 devices, 8 microbatches, unequal times (T_F=200, T_B=250, T_W=150), with a
# communication matrix built by generate_comm_mat(1, 4, 0, 0).
test_basic_schedule(
    num_dev=4,
    num_microbatches=8,
    T_F=200,
    T_B=250,
    T_W=150,
    comm_matrix=generate_comm_mat(1, 4, 0, 0),
)

In [None]:
# 32 devices, 64 microbatches, unequal times (T_F=200, T_B=250, T_W=150), with
# a communication matrix built by generate_comm_mat(1, 32, 0, 0).
test_basic_schedule(
    num_dev=32,
    num_microbatches=64,
    T_F=200,
    T_B=250,
    T_W=150,
    comm_matrix=generate_comm_mat(1, 32, 0, 0),
)

In [None]:
# 4 devices, 8 microbatches, unequal times (T_F=180, T_B=220, T_W=200), with a
# communication matrix built by generate_comm_mat(1, 4, 0, 0).
test_basic_schedule(
    num_dev=4,
    num_microbatches=8,
    T_F=180,
    T_B=220,
    T_W=200,
    comm_matrix=generate_comm_mat(1, 4, 0, 0),
)

In [None]:
# 32 devices, 64 microbatches, equal T_F/T_B/T_W of 200, with a communication
# matrix built by generate_comm_mat(2, 16, 0, 800).
test_basic_schedule(
    num_dev=32,
    num_microbatches=64,
    T_F=200,
    T_B=200,
    T_W=200,
    comm_matrix=generate_comm_mat(2, 16, 0, 800),
)

In [None]:
# 32 devices, 64 microbatches, unequal times (T_F=220, T_B=240, T_W=200), with
# a communication matrix built by generate_comm_mat(2, 16, 0, 800).
test_basic_schedule(
    num_dev=32,
    num_microbatches=64,
    T_F=220,
    T_B=240,
    T_W=200,
    comm_matrix=generate_comm_mat(2, 16, 0, 800),
)