The official heuristic schedule cannot handle the insufficient-memory case

In [None]:

from pipeline_config import SystemConfig
from pipeline import AutoWaveZBPipeline, HeuristicWaveZBPipeline, HeuristicZBVPipeline
from auto_schedule import WaveLikeZBDependencyGraph
from util import generate_comm_mat
from IPython.display import display

def compare_heuristic_milp_zbv(num_dev, num_mb, T_F, T_B, T_W, M_F, M_B, M_W, M_lim, comm_mat, print_schedule=True, with_official=True):
    """Build and compare the heuristic, official-heuristic and MILP schedules.

    A ``SystemConfig`` with ``num_chunks=2`` is created from the arguments and
    handed to ``HeuristicWaveZBPipeline``, to ``WaveLikeZBDependencyGraph`` +
    ``AutoWaveZBPipeline`` (MILP solved with a 30 second time limit and no
    warm start) and, optionally, to ``HeuristicZBVPipeline``.  The runtime of
    each schedule is printed and the schedules are optionally displayed on a
    shared time range.

    Args:
        num_dev: Forwarded as ``SystemConfig.num_devices``.
        num_mb: Forwarded as ``SystemConfig.num_microbatches``.
        T_F, T_B, T_W: Forwarded unchanged to ``SystemConfig``
            (presumably per-pass durations -- confirm against SystemConfig).
        M_F, M_B, M_W: Forwarded unchanged to ``SystemConfig``
            (presumably per-pass memory deltas -- confirm against SystemConfig).
        M_lim: Forwarded as ``SystemConfig.M_Limit``.
        comm_mat: Forwarded as ``SystemConfig.T_C``.
        print_schedule: When true, each schedule is rendered via ``display``.
        with_official: When true, the official heuristic
            (``HeuristicZBVPipeline``) is also built and reported.  When
            false it is not run at all.

    Returns:
        None.  Results are reported through ``print`` / ``display``.
    """
    cfg = SystemConfig(
        num_devices=num_dev,
        num_microbatches=num_mb,
        T_F=T_F,
        T_B=T_B,
        T_W=T_W,
        M_F=M_F,
        M_B=M_B,
        M_W=M_W,
        M_Limit=M_lim,
        T_C=comm_mat,
        num_chunks=2,
    )

    # Heuristic wave-like schedule.
    heur_pp = HeuristicWaveZBPipeline(cfg)
    heur_pp.schedule()
    heur_pp.solve_dependencies()

    # MILP schedule: solve the ILP, then replay its result in a pipeline.
    dg = WaveLikeZBDependencyGraph(cfg)
    dg.build_ilp()
    dg.solve_ilp(time_limit=30, warm_start=False, verbose=False)
    schedule = dg.get_schedule()
    milp_pp = AutoWaveZBPipeline(cfg)
    milp_pp.schedule(schedule)
    milp_pp.solve_dependencies()

    # Official heuristic: only attempted when requested.  -1 marks "not
    # available" so it never widens the shared time range below.
    off_heur_pp_time = -1
    if with_official:
        try:
            off_heur_pp = HeuristicZBVPipeline(cfg)
            off_heur_pp.schedule()
            off_heur_pp.solve_dependencies()
            off_heur_pp_time = off_heur_pp.get_schedule_time(device_wise=False)
        except Exception as e:
            # Best effort: a failure here is reported and the comparison of
            # the remaining two schedules continues.
            print(f'Error in official heuristic: {e}')
            off_heur_pp_time = -1

    heur_pp_time = heur_pp.get_schedule_time(device_wise=False)
    milp_pp_time = milp_pp.get_schedule_time(device_wise=False)
    # Common time range so the displayed schedules share one scale.
    max_time = max(heur_pp_time, off_heur_pp_time, milp_pp_time)

    print(f'Heuristic Schedule: runtime={heur_pp.get_schedule_time(device_wise=True)}')
    if print_schedule:
        display(heur_pp.print_schedule(time_range=max_time, include_info=False))

    # Report the official runtime whenever it was computed; like the other
    # two schedules, only the rendering depends on print_schedule.
    if with_official and off_heur_pp_time > 0:
        print(f'Official Heuristic Schedule: runtime={off_heur_pp.get_schedule_time(device_wise=True)}')
        if print_schedule:
            display(off_heur_pp.print_schedule(time_range=max_time, include_info=False))

    print(f'MILP Schedule: runtime={milp_pp.get_schedule_time(device_wise=True)}')
    if print_schedule:
        display(milp_pp.print_schedule(time_range=max_time, include_info=False))

Base case

In [None]:
# Base scenario: keyword arguments forwarded to the comparison helper.
args = dict(
    num_dev=4,
    num_mb=8,
    T_F=200,
    T_B=200,
    T_W=200,
    M_F=2,
    M_B=-1,
    M_W=-1,
    M_lim=16,
    comm_mat=generate_comm_mat(1, 4, 0, 0),
)

compare_heuristic_milp_zbv(**args)

Base case: 1.0x mem, low comm cost

In [None]:
# Same values as the base scenario, except the last generate_comm_mat
# argument is 5 instead of 0.
args = dict(
    num_dev=4,
    num_mb=8,
    T_F=200,
    T_B=200,
    T_W=200,
    M_F=2,
    M_B=-1,
    M_W=-1,
    M_lim=16,
    comm_mat=generate_comm_mat(1, 4, 0, 5),
)

compare_heuristic_milp_zbv(**args)

Base case: half mem limit 

In [None]:
# Base scenario with M_lim lowered to 8 (half of the 16 used in the base case).
args = dict(
    num_dev=4,
    num_mb=8,
    T_F=200,
    T_B=200,
    T_W=200,
    M_F=2,
    M_B=-1,
    M_W=-1,
    M_lim=8,
    comm_mat=generate_comm_mat(1, 4, 0, 0),
)

compare_heuristic_milp_zbv(**args)

Base case: imbalance B,W, 1.0x mem

In [None]:
# Unequal B/W settings (T_B=230 vs T_W=170, M_B=-8 vs M_W=-12) with M_lim=160.
args = dict(
    num_dev=4,
    num_mb=8,
    T_F=200,
    T_B=230,
    T_W=170,
    M_F=20,
    M_B=-8,
    M_W=-12,
    M_lim=160,
    comm_mat=generate_comm_mat(1, 4, 0, 0),
)

compare_heuristic_milp_zbv(**args)

Base case: imbalance B,W, 0.75x mem

In [None]:
# Unequal B/W settings with M_lim=120 (0.75 x 160).
args = dict(
    num_dev=4,
    num_mb=8,
    T_F=200,
    T_B=230,
    T_W=170,
    M_F=20,
    M_B=-8,
    M_W=-12,
    M_lim=120,
    comm_mat=generate_comm_mat(1, 4, 0, 0),
)

compare_heuristic_milp_zbv(**args)

Base case: imbalance B,W, half mem

In [None]:
# Unequal B/W settings with M_lim=80 (0.5 x 160).
args = dict(
    num_dev=4,
    num_mb=8,
    T_F=200,
    T_B=230,
    T_W=170,
    M_F=20,
    M_B=-8,
    M_W=-12,
    M_lim=80,
    comm_mat=generate_comm_mat(1, 4, 0, 0),
)

compare_heuristic_milp_zbv(**args)

cross-DC: 1.0x mem, low comm cost

In [None]:
# Cross-DC scenario: comm matrix from generate_comm_mat(2, 2, 0, 200), M_lim=160.
args = dict(
    num_dev=4,
    num_mb=8,
    T_F=200,
    T_B=200,
    T_W=200,
    M_F=20,
    M_B=-10,
    M_W=-10,
    M_lim=160,
    comm_mat=generate_comm_mat(2, 2, 0, 200),
)
# Debug aid (disabled):
# print(generate_comm_mat(2, 4, 0, 200))

# The official heuristic is excluded from this comparison.
compare_heuristic_milp_zbv(**args, with_official=False)

cross-DC: inf mem, low comm cost

In [None]:
# Cross-DC scenario with a very large memory limit (M_lim=2000) and the
# comm matrix from generate_comm_mat(2, 2, 0, 200).
args = dict(
    num_dev=4,
    num_mb=8,
    T_F=200,
    T_B=200,
    T_W=200,
    M_F=20,
    M_B=-10,
    M_W=-10,
    M_lim=2000,
    comm_mat=generate_comm_mat(2, 2, 0, 200),
)

# The official heuristic is excluded from this comparison.
compare_heuristic_milp_zbv(**args, with_official=False)

cross-DC: 1.0x mem, high comm cost

In [None]:
# Cross-DC scenario with M_lim=160 and the comm matrix from
# generate_comm_mat(2, 2, 0, 600).
args = dict(
    num_dev=4,
    num_mb=8,
    T_F=200,
    T_B=200,
    T_W=200,
    M_F=20,
    M_B=-10,
    M_W=-10,
    M_lim=160,
    comm_mat=generate_comm_mat(2, 2, 0, 600),
)

# The official heuristic is excluded from this comparison.
compare_heuristic_milp_zbv(**args, with_official=False)

cross-DC: inf mem, high comm cost

In [None]:
# Cross-DC scenario with a very large memory limit (M_lim=20000) and the
# comm matrix from generate_comm_mat(2, 2, 0, 600).
args = dict(
    num_dev=4,
    num_mb=8,
    T_F=200,
    T_B=200,
    T_W=200,
    M_F=20,
    M_B=-10,
    M_W=-10,
    M_lim=20000,
    comm_mat=generate_comm_mat(2, 2, 0, 600),
)

# The official heuristic is excluded from this comparison.
compare_heuristic_milp_zbv(**args, with_official=False)

cross-DC: inf mem, high comm cost, 1.5x mb

In [None]:
# Cross-DC scenario with M_lim=20000, the comm matrix from
# generate_comm_mat(2, 2, 0, 600), and 12 microbatches (1.5 x 8).
args = dict(
    num_dev=4,
    num_mb=12,
    T_F=200,
    T_B=200,
    T_W=200,
    M_F=20,
    M_B=-10,
    M_W=-10,
    M_lim=20000,
    comm_mat=generate_comm_mat(2, 2, 0, 600),
)

# The official heuristic is excluded from this comparison.
compare_heuristic_milp_zbv(**args, with_official=False)