Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement PackedGraph.to() #70

Merged
merged 3 commits into from
Feb 17, 2022
Merged

Conversation

cthoyt
Copy link
Contributor

@cthoyt cthoyt commented Feb 8, 2022

This pull request implements a to() function on the PackedGraph class. Many models expect a function like this available for shuttling between CPU and GPU.

  • Handle when a str is passed (done in bd3482b)
  • Handle when a torch.device is passed (done in 84aa285)

Question: should it be able to handle other things?

@cthoyt cthoyt marked this pull request as ready for review February 8, 2022 14:50
@KiddoZhu
Copy link
Contributor

That's a good feature! I feel it might not be robust to parse the string in that way. For example, in multi-GPU training mode, one may want to specify which GPU they want to send the data structure to.

A safer way is to convert the string (or whatever argument) into torch.device and then dispatch it with data.Graph.cpu or data.Graph.cuda. I will refine it.

@cthoyt
Copy link
Contributor Author

cthoyt commented Feb 16, 2022

It’s definitely not the goal to implement a robust string parser - just to make it possible to do a few common things very quickly. Note that the function also handles torch devices directly as well in the next part of the conditional

@KiddoZhu KiddoZhu merged commit 17477ca into DeepGraphLearning:master Feb 17, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants