Support embedding & Test auto-sharding on the whole BERT model & Refine auto-sharding interface (#49)

* merge into one commit

* clean the interface of auto-sharding

* rename
merrymercy committed Jul 10, 2021
1 parent f2be873 commit 47fb454
Showing 22 changed files with 1,242 additions and 317 deletions.
12 changes: 6 additions & 6 deletions benchmark/megatron/benchmark_mlp.py
@@ -6,14 +6,14 @@ def run_cmd(cmd):

benchmark_suite = [
# Batch size, seq_len, hidden size, num_layers, num_heads, dp_size, tensor_mp_size,
(16, 1024, 2304, 4, 2304//96, 4, 1),
(16, 1024, 2304, 4, 2304//96, 2, 2),
(16, 1024, 2304, 4, 2304//96, 1, 4),
(32, 1024, 2304, 4, 2304//96, 4, 1),
(32, 1024, 2304, 4, 2304//96, 2, 2),
(32, 1024, 2304, 4, 2304//96, 1, 4),

# Batch size, seq_len, hidden size, num_layers, num_heads, dp_size, tensor_mp_size,
(8, 256, 2304, 4, 2304//96, 4, 1),
(8, 256, 2304, 4, 2304//96, 2, 2),
(8, 256, 2304, 4, 2304//96, 1, 4),
(8, 256, 5760, 4, 5760//96, 4, 1),
(8, 256, 5760, 4, 5760//96, 2, 2),
(8, 256, 5760, 4, 5760//96, 1, 4),
]

def benchmark_all():
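Each entry in the benchmark_suite above packs one configuration into a flat tuple. A minimal sketch of unpacking a case, with field names taken from the inline comment (the final assert only restates the 2304 // 96 head count used in the suite):

# Hypothetical unpacking of one Megatron MLP benchmark case.
case = (32, 1024, 2304, 4, 2304 // 96, 4, 1)
(batch_size, seq_len, hidden_size, num_layers,
 num_heads, dp_size, tensor_mp_size) = case
# A per-head dimension of 96 is implied by the 2304 // 96 expression above.
assert hidden_size // num_heads == 96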
28 changes: 17 additions & 11 deletions benchmark/megatron/benchmark_mlp_one_case.py
@@ -16,6 +16,7 @@
from timeit_v2 import py_benchmark

MB = 1024 ** 2
GB = 1024 ** 3


def get_memory_usage(print_info=False):
@@ -30,6 +31,18 @@ def get_memory_usage(print_info=False):
return allocated


def write_tsv(heads, values, filename, print_line=True):
    """Write tsv data to a file."""
    with open(filename, "a") as fout:
        fout.write("\t".join(values) + "\n")

    if print_line:
        line = ""
        for i in range(len(heads)):
            line += heads[i] + ": " + values[i] + " "
        print(line)


class MultiLayerMLP(torch.nn.Module):
def __init__(self, num_layers):
super().__init__()
@@ -133,17 +146,10 @@ def func(record_peak=False):
# Print results
if rank == 0:
peak_mem = torch.cuda.max_memory_allocated(0)
line = f"Type: mlp\t"\
f"Case: {benchmark_case}\t"\
f"WeightMem: {weight_mem/MB:.2f}\t"\
f"PeakMem: {peak_mem/MB:.2f}\t"\
f"BackwardMem: {before_backward_mem/MB:.2f}\t"\
f"Mean Time: {np.mean(costs):.2f}\t"\
f"Std Time: {np.std(costs):.2f}"

print(line)
with open("results.tsv", "a") as fout:
fout.write(line + "\n")
heads = ["Type", "Case", "WeightMem", "PeakMem", "Mean Time", "Std Time"]
values = ["mlp", str(benchmark_case), f"{weight_mem/GB:.2f}", f"{peak_mem/GB:.2f}",
f"{np.mean(costs):.3f}", f"{np.std(costs):.3f}"]
write_tsv(heads, values, "result_mlp.tsv")


if __name__ == "__main__":
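The new write_tsv helper appends one tab-separated row of values per run (optionally echoing them as "Head: value" pairs) and writes no header row. A small sketch of reading such a results file back, assuming the column order used at this call site:

import csv

# Illustrative only: column order follows the heads list above.
with open("result_mlp.tsv") as fin:
    for row in csv.reader(fin, delimiter="\t"):
        run_type, case, weight_mem, peak_mem, mean_time, std_time = row
        print(run_type, case, float(mean_time), float(std_time))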
18 changes: 6 additions & 12 deletions benchmark/megatron/benchmark_transformer_layer_one_case.py
@@ -14,6 +14,7 @@


from timeit_v2 import py_benchmark
from benchmark_mlp_one_case import write_tsv

GB = 1024 ** 3

@@ -127,21 +128,14 @@ def func(record_act_mem=False):

# Print results
if rank == 0:
heads = ["Type", "Case", "Mesh Shape", "DDP Impl", "Peak Mem",
"Weight Mem", "ActMem", "Mean Time", "Std Time"]
heads = ["Type", "Case", "Mesh Shape", "DDP Impl", "Weight Mem",
"Peak Mem", "Mean Time", "Std Time"]
values = ["transformer-layer", str(benchmark_case[:-3]),
str(benchmark_case[-3:-1]), str(benchmark_case[-1]),
f"{peak_mem/GB:5.3f}", f"{weight_mem/GB:5.3f}",
f"{act_mem[0]/GB:5.3f}",
f"{np.mean(costs):.2f}", f"{np.std(costs):.2f}"]
f"{weight_mem/GB:5.3f}", f"{peak_mem/GB:5.3f}",
f"{np.mean(costs):.3f}", f"{np.std(costs):.3f}"]
write_tsv(heads, values, "result_trans.tsv")

line = ""
for i in range(len(heads)):
line += heads[i] + ": " + values[i] + " "
print(line)

with open("results.tsv", "a") as fout:
fout.write("\t".join(values) + "\n")


if __name__ == "__main__":
15 changes: 9 additions & 6 deletions benchmark/parax/benchmark_mlp.py
@@ -12,6 +12,7 @@
import ray

from parax import parallelize, set_parallelize_options, testing, global_config, DeviceCluster
from parax.testing import assert_only_has_allreduce
from parax.util import write_tsv

MB = 1024 ** 2
@@ -108,7 +109,7 @@ def func():
func()
stmt = "func()"
repeat = 2
number = 10
number = args.number
costs = np.array(timeit.repeat(stmt, globals={**globals(), **locals()},
repeat=repeat, number=number)) / number
real_mem = testing.last_compiled_executable.total_allocation_size()
@@ -117,25 +118,26 @@ def func():
# Check sharding strategy
hlo_module = testing.last_compiled_executable.hlo_modules()[0]
hlo_ir = hlo_module.to_string()
assert_only_has_allreduce(hlo_ir)
#print("===== HLO =====")
#print(hlo_ir)

#optimizer = closure[0]
#sharding_specs = jax.tree_util.tree_map(lambda x: x.sharding_spec, optimizer)

# Log benchmark results
heads = ["Case", "PeakMem", "Objective", "Mean Time", "Std Time"]
values = [str(benchmark_case), f"{real_mem/GB:.2f}", f"{objective:.2f}",
f"{np.mean(costs):.2f}", f"{np.std(costs):.2f}"]
heads = ["Type", "Case", "PeakMem", "Objective", "Mean Time", "Std Time"]
values = ["mlp", str(benchmark_case), f"{real_mem/GB:.2f}", f"{objective:.2f}",
f"{np.mean(costs):.3f}", f"{np.std(costs):.3f}"]
write_tsv(heads, values, "result_mlp.tsv")

physical_mesh.shutdown()


benchmark_suite = [
# Batch size, seq_len, hidden size, num_layers, dp_size, tensor_mp_size,
(16, 1024, 2304, 4, 4, 1),
(16, 1024, 2304, 4, 2, 2),
(32, 1024, 2304, 4, 4, 1),
(32, 1024, 2304, 4, 2, 2),

# Batch size, seq_len, hidden size, num_layers, dp_size, tensor_mp_size,
(8, 256, 5760, 4, 4, 1),
@@ -151,6 +153,7 @@ def benchmark_all(use_profiling):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--use-profiling", action="store_true")
parser.add_argument("--number", type=int, default=10)
args = parser.parse_args()

ray.init(address="auto")
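The new --number flag replaces the hard-coded iteration count: each of the repeat timing rounds runs the step number times, and the totals are divided by number to get per-iteration costs. A standalone sketch of the same pattern, with a dummy func standing in for the benchmarked step:

import timeit
import numpy as np

def func():
    sum(range(10000))  # dummy stand-in for one benchmarked step

repeat, number = 2, 10  # "number" would come from args.number in the benchmark script
costs = np.array(timeit.repeat("func()", globals=globals(),
                               repeat=repeat, number=number)) / number
print(f"Mean Time: {np.mean(costs):.3f}\tStd Time: {np.std(costs):.3f}")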
5 changes: 3 additions & 2 deletions benchmark/parax/benchmark_transformer_layer.py
@@ -132,7 +132,7 @@ def func():
func()
stmt = "func()"
repeat = 2
number = 10
number = args.number
costs = np.array(timeit.repeat(stmt, globals={**globals(), **locals()},
repeat=repeat, number=number)) / number
real_mem = testing.last_compiled_executable.total_allocation_size()
@@ -151,7 +151,7 @@ def func():
heads = ["Type", "Case", "Mesh Shape", "Peak Mem", "Objective", "Mean Time", "Std Time"]
values = ["transformer-layer", str(benchmark_case[:-2]), str(benchmark_case[-2:]),
f"{real_mem/GB:.3f}", f"{objective:.2f}",
f"{np.mean(costs):.2f}", f"{np.std(costs):.2f}"]
f"{np.mean(costs):.3f}", f"{np.std(costs):.3f}"]
write_tsv(heads, values, "result_trans.tsv")

physical_mesh.shutdown()
@@ -195,6 +195,7 @@ def benchmark_all(use_profiling):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--use-profiling", action="store_true")
parser.add_argument("--number", type=int, default=10)
args = parser.parse_args()

ray.init(address="auto")
7 changes: 4 additions & 3 deletions parax/api.py
@@ -196,8 +196,9 @@ def auto_parallel_callable(

return auto_sharding_callable(
fun, in_tree, out_tree_thunk, donated_invars,
physical_mesh, global_config.mesh_shape_search_mode,
logical_mesh_choices, memory_budget_per_device,
physical_mesh, logical_mesh_choices,
global_config.mesh_shape_search_mode,
memory_budget_per_device,
search_task, record_file, strategy_config, *avals
)
elif strategy == "shard_data_parallel":
@@ -235,7 +236,7 @@ def get_compute_key(fun, in_tree, donated_invars, *aval):
# input arguments specification to a string.
# Then compute a hash value of this string.
#
# TOOD(lmzheng): use jaxpr or hlo instead of source code?
# TODO(lmzheng): use jaxpr or hlo instead of source code?

location = fun.f.__str__().split("at")[0]
source_code = inspect.getsource(fun.f)
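The hunk above only changes the positional order at the call site: logical_mesh_choices now comes directly after physical_mesh, followed by the mesh-shape search mode and the memory budget. The definition of auto_sharding_callable is not part of this diff; the stub below merely mirrors the order implied by the call, and the parameter names are assumptions:

# Assumed signature, reconstructed from the call site in parax/api.py;
# the real definition lives in the auto-sharding module and is not shown here.
def auto_sharding_callable(fun, in_tree, out_tree_thunk, donated_invars,
                           physical_mesh, logical_mesh_choices,
                           mesh_shape_search_mode, memory_budget_per_device,
                           search_task, record_file, strategy_config, *avals):
    ...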