Add default config and update dataset in gpups (#43327)
* gpups default config and dataset

* codestyle

* add unittest

* code style
esythan committed Jun 13, 2022
1 parent 5d48528 commit 24ea1dd
Showing 3 changed files with 58 additions and 9 deletions.
12 changes: 8 additions & 4 deletions python/paddle/distributed/fleet/dataset/dataset.py
100644 → 100755
@@ -583,6 +583,12 @@ def init(self, **kwargs):
         pipe_command = kwargs.get("pipe_command", "cat")
         download_cmd = kwargs.get("download_cmd", "cat")
 
+        if self.use_ps_gpu:
+            data_feed_type = "SlotRecordInMemoryDataFeed"
+        else:
+            data_feed_type = "MultiSlotInMemoryDataFeed"
+        self._set_feed_type(data_feed_type)
+
         super(InMemoryDataset, self).init(batch_size=batch_size,
                                           thread_num=thread_num,
                                           use_var=use_var,
@@ -592,10 +598,6 @@ def init(self, **kwargs):
                                           fs_ugi=fs_ugi,
                                           download_cmd=download_cmd)
 
-        data_feed_type = kwargs.get("data_feed_type",
-                                    "MultiSlotInMemoryDataFeed")
-        self._set_feed_type(data_feed_type)
-
         if kwargs.get("queue_num", -1) > 0:
             queue_num = kwargs.get("queue_num", -1)
             self._set_queue_num(queue_num)
@@ -605,6 +607,8 @@ def _set_feed_type(self, data_feed_type):
         Set data_feed_desc
         """
         self.proto_desc.name = data_feed_type
+        if (self.proto_desc.name == "SlotRecordInMemoryDataFeed"):
+            self.dataset = core.Dataset("SlotRecordDataset")
 
     def _prepare_to_run(self):
         """
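In effect, InMemoryDataset no longer reads a data_feed_type kwarg; the feed type is derived from the dataset's use_ps_gpu flag before the base init() runs. A minimal usage sketch of the new behavior (assuming a Paddle build containing this commit; the single slot variable is an illustrative placeholder):

import paddle
import paddle.fluid as fluid

paddle.enable_static()

# One hypothetical sparse slot input, mirroring the unit test below.
slot_var = fluid.layers.data(name="slot1", shape=[1],
                             dtype="int64", lod_level=1)

dataset = paddle.distributed.InMemoryDataset()
dataset._set_use_ps_gpu(True)  # GPUPS: init() now selects SlotRecordInMemoryDataFeed
dataset.init(batch_size=32,
             thread_num=3,
             pipe_command="cat",
             use_var=[slot_var])
# Without _set_use_ps_gpu(True), init() falls back to MultiSlotInMemoryDataFeed.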
16 changes: 11 additions & 5 deletions python/paddle/distributed/ps/the_one_ps.py
@@ -138,7 +138,7 @@ def _set(self, accessor_proto, varname, program_id, context):
         if not accessor_proto.HasField("accessor_class"):
             # DownpourSparseValueAccessor
             if context['use_ps_gpu']:
-                accessor_proto.accessor_class = "CtrCommonAccessor"
+                accessor_proto.accessor_class = "CtrDymfAccessor"
             else:
                 accessor_proto.accessor_class = "SparseAccessor"
             if not accessor_proto.HasField("fea_dim"):
@@ -601,10 +601,16 @@ def _set(self, table_proto):
         if usr_table_proto.HasField("shard_num"):
             table_proto.shard_num = usr_table_proto.shard_num
         else:
-            table_proto.shard_num = 1000
-            warnings.warn(
-                "The shard_num of sparse table is not set, use default value 1000."
-            )
+            if self.context['use_ps_gpu']:
+                table_proto.shard_num = 37
+                warnings.warn(
+                    "The shard_num of sparse table is not set, use default value 37 in gpups."
+                )
+            else:
+                table_proto.shard_num = 1000
+                warnings.warn(
+                    "The shard_num of sparse table is not set, use default value 1000 in cpups."
+                )
 
         if usr_table_proto.accessor.ByteSize() == 0:
             warnings.warn(
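The fallback now distinguishes the two parameter-server modes: an explicit shard_num from the user's table proto still wins, otherwise GPUPS tables default to 37 shards while CPU tables keep the old default of 1000. A standalone sketch of that decision (illustrative names, not the actual the_one_ps internals):

import warnings

def resolve_shard_num(user_shard_num, use_ps_gpu):
    # An explicit user setting always takes precedence.
    if user_shard_num is not None:
        return user_shard_num
    # Otherwise fall back per mode, warning as the real code does.
    default, mode = (37, "gpups") if use_ps_gpu else (1000, "cpups")
    warnings.warn("The shard_num of sparse table is not set, "
                  "use default value %d in %s." % (default, mode))
    return default

resolve_shard_num(None, True)   # warns, returns 37
resolve_shard_num(512, True)    # returns 512, no warning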
39 changes: 39 additions & 0 deletions python/paddle/fluid/tests/unittests/test_dist_fleet_ps11.py
100644 → 100755
@@ -185,6 +185,45 @@ def test(self):
         optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
         optimizer.minimize(loss)
 
+    def test_gpups_dataset(self):
+        """
+        Testcase for GPUPS InMemoryDataset.
+        """
+        with open("test_in_memory_dataset_run_a.txt", "w") as f:
+            data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
+            data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
+            data += "1 3 2 3 5 4 7 7 7 7 1 3\n"
+            f.write(data)
+        with open("test_in_memory_dataset_run_b.txt", "w") as f:
+            data = "1 4 2 3 3 4 5 5 5 5 1 4\n"
+            data += "1 5 2 3 4 4 6 6 6 6 1 5\n"
+            data += "1 6 2 3 5 4 7 7 7 7 1 6\n"
+            data += "1 7 2 3 6 4 8 8 8 8 1 7\n"
+            f.write(data)
+
+        slots = ["slot1", "slot2", "slot3", "slot4"]
+        slots_vars = []
+        for slot in slots:
+            var = fluid.layers.data(name=slot,
+                                    shape=[1],
+                                    dtype="int64",
+                                    lod_level=1)
+            slots_vars.append(var)
+
+        dataset = paddle.distributed.InMemoryDataset()
+        dataset._set_use_ps_gpu(True)
+        dataset.init(batch_size=32,
+                     thread_num=3,
+                     pipe_command="cat",
+                     use_var=slots_vars)
+        dataset.set_filelist([
+            "test_in_memory_dataset_run_a.txt",
+            "test_in_memory_dataset_run_b.txt"
+        ])
+
+        os.remove("./test_in_memory_dataset_run_a.txt")
+        os.remove("./test_in_memory_dataset_run_b.txt")
+
 
 if __name__ == '__main__':
     unittest.main()
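As committed, the test only builds the GPUPS dataset and its filelist before deleting the input files; it exercises the new SlotRecord defaults without training. A hypothetical continuation showing how such a dataset is normally consumed (standard InMemoryDataset and Executor calls, not part of this commit):

# Sketch only: would follow set_filelist() and precede os.remove().
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
dataset.load_into_memory()                  # read the filelist via pipe_command
exe.train_from_dataset(program=fluid.default_main_program(),
                       dataset=dataset)
dataset.release_memory()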
