pass all test
merrymercy committed Jul 10, 2021
1 parent d2fc4a1 commit 39c2bd8
Showing 5 changed files with 17 additions and 13 deletions.
benchmark/parax/benchmark_mlp.py: 9 changes (6 additions, 3 deletions)
@@ -12,6 +12,7 @@
 import ray
 
 from parax import parallelize, set_parallelize_options, testing, global_config, DeviceCluster
+from parax.testing import assert_only_has_allreduce
 from parax.util import write_tsv
 
 MB = 1024 ** 2
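The newly imported assert_only_has_allreduce (invoked further down, right after compilation) is not defined in this diff. As a rough sketch, assuming the parax.testing helper simply scans the HLO text for collective ops other than all-reduce (the real implementation may differ):

def assert_only_has_allreduce(hlo_ir: str):
    """Hypothetical check: fail if the HLO text contains collectives other than all-reduce."""
    forbidden = ("all-gather", "all-to-all", "reduce-scatter", "collective-permute")
    for op in forbidden:
        assert op not in hlo_ir, f"unexpected collective in HLO: {op}"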
@@ -108,7 +109,7 @@ def func():
     func()
     stmt = "func()"
     repeat = 2
-    number = 10
+    number = args.number
     costs = np.array(timeit.repeat(stmt, globals={**globals(), **locals()},
                                    repeat=repeat, number=number)) / number
     real_mem = testing.last_compiled_executable.total_allocation_size()
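For context, the timing code above is the standard timeit.repeat pattern, with the formerly hard-coded number now supplied by the new --number flag (its default of 10 preserves the old behavior). A minimal self-contained sketch of the same measurement, with func() standing in for one benchmarked step:

import timeit
import numpy as np

def func():
    sum(range(10000))  # stand-in for one training step of the real benchmark

repeat, number = 2, 10  # `number` is what --number now controls
# timeit.repeat returns `repeat` totals, each accumulated over `number` calls,
# so dividing by `number` yields per-call costs, as in the benchmark scripts.
costs = np.array(timeit.repeat("func()", globals=globals(),
                               repeat=repeat, number=number)) / number
print(f"mean {np.mean(costs):.3f} s, std {np.std(costs):.3f} s")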
@@ -117,15 +118,16 @@ def func():
     # Check sharding strategy
     hlo_module = testing.last_compiled_executable.hlo_modules()[0]
     hlo_ir = hlo_module.to_string()
+    assert_only_has_allreduce(hlo_ir)
     #print("===== HLO =====")
     #print(hlo_ir)
 
     #optimizer = closure[0]
     #sharding_specs = jax.tree_util.tree_map(lambda x: x.sharding_spec, optimizer)
 
     # Log benchmark results
-    heads = ["Case", "PeakMem", "Objective", "Mean Time", "Std Time"]
-    values = [str(benchmark_case), f"{real_mem/GB:.2f}", f"{objective:.2f}",
+    heads = ["Type", "Case", "PeakMem", "Objective", "Mean Time", "Std Time"]
+    values = ["mlp", str(benchmark_case), f"{real_mem/GB:.2f}", f"{objective:.2f}",
               f"{np.mean(costs):.3f}", f"{np.std(costs):.3f}"]
     write_tsv(heads, values, "result_mlp.tsv")
 
@@ -151,6 +153,7 @@ def benchmark_all(use_profiling):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--use-profiling", action="store_true")
+    parser.add_argument("--number", type=int, default=10)
     args = parser.parse_args()
 
     ray.init(address="auto")
benchmark/parax/benchmark_transformer_layer.py: 3 changes (2 additions, 1 deletion)
@@ -132,7 +132,7 @@ def func():
     func()
     stmt = "func()"
     repeat = 2
-    number = 10
+    number = args.number
     costs = np.array(timeit.repeat(stmt, globals={**globals(), **locals()},
                                    repeat=repeat, number=number)) / number
     real_mem = testing.last_compiled_executable.total_allocation_size()
@@ -195,6 +195,7 @@ def benchmark_all(use_profiling):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--use-profiling", action="store_true")
+    parser.add_argument("--number", type=int, default=10)
     args = parser.parse_args()
 
     ray.init(address="auto")
parax/auto_sharding.py: 3 changes (2 additions, 1 deletion)
@@ -165,8 +165,9 @@ def _invoke_compilation(logical_mesh):
         # Solver options
         "auto_sharding::enable": True,
         "auto_sharding::memory_budget_per_device": memory_budget_per_device,
-        "auto_sharding::force_all_gather_cost": global_config.disable_all_gather,
+        "auto_sharding::force_all_gather_cost": not global_config.allow_all_gather,
         "auto_sharding::all_gather_cost": 1e10,
+        "auto_sharding::allow_recompute_heavy_op": global_config.allow_recompute_heavy_op,
 
         # Device mesh
         "auto_sharding::device_mesh_ids": logical_mesh.flatten_ids,
parax/global_env.py: 3 changes (2 additions, 1 deletion)
@@ -21,7 +21,8 @@ def __init__(self):
         self.cache_auto_sharding_ilp_solution = False
 
         ########## Options for auto-sharding solver ##########
-        self.disable_all_gather = False  # Do not allow all-gather during re-sharding.
+        self.allow_all_gather = True  # Allow all-gather during re-sharding.
+        self.allow_recompute_heavy_op = False  # Allow replicated dot computation.
 
         ########## Options for benchmark ##########
         # If true, the system is allowed to use dummy values during
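Together, the auto_sharding.py and global_env.py changes rename the knob disable_all_gather to allow_all_gather and add allow_recompute_heavy_op, both of which are forwarded to the solver options above. A small usage sketch, mirroring how the updated tests below toggle them:

from parax import global_config

# Discourage all-gather during re-sharding
# (the old equivalent was global_config.disable_all_gather = True).
global_config.allow_all_gather = False
# Let the solver replicate heavy ops such as dot; this is forwarded as
# "auto_sharding::allow_recompute_heavy_op" in the compile options.
global_config.allow_recompute_heavy_op = True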
tests/test_auto_sharding_bert.py: 12 changes (5 additions, 7 deletions)
@@ -21,8 +21,6 @@
                                      assert_replicated_row_partitioned,
                                      assert_row_partitioned)
 
-MB = 1024 ** 2
-
 
 def inspect_params(optimizer):
     """For debug usage."""
@@ -336,10 +334,9 @@ def test_bert_layer_2d_mesh(self):
         num_layers = 2
         batch_size = 8
         seq_len = 8
-        hidden_size = 256
+        hidden_size = 128
         num_heads = 8
         deterministic = False
-        global_config.disable_all_gather = True  # Temporary hack
 
         # Test on different logical mesh shapes
         mesh_shape = [2, 2]
@@ -353,7 +350,7 @@
         expected = sum(device_mesh.all_reduce_cost(
             np.prod(x.shape) * 4 / mesh_shape[1], 0) for x in params) +\
             device_mesh.all_reduce_cost(
-                batch_size * seq_len * hidden_size * 4 / mesh_shape[0], 1)
+                batch_size * seq_len * hidden_size * 4 / mesh_shape[0], 1) * (num_layers * 4 - 1)
         assert_close(objective, expected)
         assert_only_has_allreduce(hlo_ir)
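The updated expectation multiplies the activation all-reduce term by (num_layers * 4 - 1), taken here as given by the test. With this test's settings, the byte counts work out as follows:

# Values set earlier in test_bert_layer_2d_mesh
batch_size, seq_len, hidden_size = 8, 8, 128
num_layers = 2
mesh_shape = [2, 2]

# fp32 activation bytes all-reduced along the second logical-mesh dimension
activation_bytes = batch_size * seq_len * hidden_size * 4 / mesh_shape[0]  # 16384.0
num_activation_allreduces = num_layers * 4 - 1                             # 7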

@@ -476,7 +473,7 @@ def test_bert_mlm_model_parallel(self):
         num_layers = 2
         vocab_size = 512
         deterministic = False
-        global_config.disable_all_gather = True  # Temporary hack
+        global_config.allow_all_gather = False  # Temporary hack
 
         # Test on different logical mesh shapes
         for i, mesh_shape in enumerate([ (4, 1), (1, 4) ]):
@@ -532,7 +529,8 @@ def test_bert_mlm_2d_mesh(self):
         num_layers = 2
         vocab_size = 4096
         deterministic = False
-        global_config.disable_all_gather = True  # Temporary hack
+        global_config.allow_all_gather = False  # Temporary hack
+        global_config.allow_recompute_heavy_op = True
 
         mesh_shape = [2, 2]
         device_mesh = self.get_device_mesh(mesh_shape, [2, 2], [1, 0.1])
