Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEA: Distributed Recommendation Implemention #1338

Merged
merged 10 commits into from
Jul 8, 2022
Merged

FEA: Distributed Recommendation Implemention #1338

merged 10 commits into from
Jul 8, 2022

Conversation

Ethan-TZ
Copy link
Member

@Ethan-TZ Ethan-TZ commented Jul 6, 2022

No description provided.

Copy link
Member

@hyp1231 hyp1231 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just started to review configurator.py. Please feel free to raise your concerns against the reviews.

Comment on lines 344 to 345
gpu_list = self.final_config_dict['gpu_ids']
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpu_list)
Copy link
Member

@hyp1231 hyp1231 Jul 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • I suggest not changing the name of a widely-used arg like gpu_id. In my opinion, it's OK to use gpu_id even for multiple GPU IDs.
  • When input multiple GPU IDs, maybe it's better to use gpu_id: "1,2,3,4" (as a string) rather than gpu_id: [1, 2, 3, 4] (as a List). As users may input this arg via command line, such as python run_recbole.py --gpu_id=1,2,3,4., and the List may be difficult to input via command line.
  • Better to assign an initial value for gpu_id, or users have to specify an additional arg whenever they want to run.
Suggested change
gpu_list = self.final_config_dict['gpu_ids']
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpu_list)
gpu_list = self.final_config_dict['gpu_id']
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_list

self.final_config_dict['device'] = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu")
gpu_list = self.final_config_dict['gpu_ids']
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpu_list)
self.final_config_dict['SingleSpec'] = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems that most existing arg names follow a style like single_spec. Please feel free to point out if there are some specific concerns about the naming styles.

Besides, will it be more clear if we move this line after if 'local_rank' not in self.final_config_dict: and before else:?

@@ -16,7 +16,6 @@
import os
import sys
import yaml
import torch
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems that regardless of the existence of local_rank, we need to import torch. So what are the concerns of removing this line here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In order to make the setting of environment variables effective, we must put the os.environ["CUDA_VISIBLE_DEVICES"] behind import torch.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool! Thanks.

Comment on lines 355 to 356
torch.distributed.init_process_group(backend='nccl', rank = self.final_config_dict['local_rank'], world_size = self.final_config_dict['world_size'],
init_method='tcp://' + self.final_config_dict['ip'] + ':' + str(self.final_config_dict['port']))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please take care of the coding style. [PEP8]

Suggested change
torch.distributed.init_process_group(backend='nccl', rank = self.final_config_dict['local_rank'], world_size = self.final_config_dict['world_size'],
init_method='tcp://' + self.final_config_dict['ip'] + ':' + str(self.final_config_dict['port']))
torch.distributed.init_process_group(
backend='nccl', rank=self.final_config_dict['local_rank'],
world_size=self.final_config_dict['world_size'],
init_method='tcp://' + self.final_config_dict['ip'] + ':' + str(self.final_config_dict['port']))

WX20220707-145836@2x

WX20220707-145913@2x

if 'local_rank' not in self.final_config_dict:
import torch
self.final_config_dict['local_rank'] = 0
self.final_config_dict['device'] = torch.device("cpu") if len(gpu_list) == 0 else torch.device("cuda")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to check with torch.cuda.is_available()?

@Ethan-TZ Ethan-TZ merged commit 063bfe7 into RUCAIBox:1.1.x Jul 8, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants