Simplify tfloat, tlong, tint, tbool... #312

Open
alexhernandezgarcia opened this issue Jun 4, 2024 · 1 comment

@alexhernandezgarcia
Owner

Currently, the codebase uses helper methods (tfloat, tlong, tint, tbool) to convert numbers / lists / arrays into tensors with the corresponding dtype and send them to the right device. These methods are implemented in gflownet/utils/common.py.
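For context, here is a minimal sketch of what a `tfloat`-style helper does. It assumes the helpers accept numbers, (nested) lists, NumPy arrays, tensors and lists of tensors. This is not the actual code in gflownet/utils/common.py: `_to_tensor` is a hypothetical shared helper, and the `tlong` and `tbool` signatures are assumptions. Only the `tfloat(states, device=..., float_type=...)` call shape comes from this issue.

```python
import torch


def _to_tensor(x, device, dtype):
    """Hypothetical shared helper: convert x to a tensor of `dtype` on `device`."""
    # A non-empty list of tensors is stacked along a new first dimension.
    if isinstance(x, list) and len(x) > 0 and torch.is_tensor(x[0]):
        return torch.stack(x).to(device=device, dtype=dtype)
    # An existing tensor is cast and moved (returned as-is if nothing changes).
    if torch.is_tensor(x):
        return x.to(device=device, dtype=dtype)
    # Numbers, (nested) lists and NumPy arrays go through torch.tensor.
    return torch.tensor(x, dtype=dtype, device=device)


def tfloat(x, device, float_type):
    return _to_tensor(x, device, float_type)


def tlong(x, device):
    # Assumed signature: the dtype is fixed, so no type argument is needed.
    return _to_tensor(x, device, torch.long)


def tbool(x, device):
    # Assumed signature, same pattern as tlong.
    return _to_tensor(x, device, torch.bool)
```

For example, `tfloat([[0, 1], [2, 3]], device="cpu", float_type=torch.float32)` would return a 2x2 float32 tensor on the CPU.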

For example, if we want to convert a batch of states into a float tensor, we can do the following:

states = tfloat(states, device=self.device, float_type=self.float)

While this has some advantages, it is rather annoying to have to pass device and float_type explicitly every time, which makes for a pretty long line.

I wonder if there is a neat way of changing things so that we could simply write:

states = tfloat(states)
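One possible way to get the short call, sketched under assumptions: bind device and float_type once per object with `functools.partial`, so call sites only pass the data. `ToyEnv` and `to_float` are hypothetical names, and the `tfloat` below is a minimal stand-in for the existing helper, not its actual implementation.

```python
from functools import partial

import torch


def tfloat(x, device, float_type):
    # Minimal stand-in for the existing helper (same call shape as above).
    if torch.is_tensor(x):
        return x.to(device=device, dtype=float_type)
    return torch.tensor(x, dtype=float_type, device=device)


class ToyEnv:
    """Hypothetical class that owns a device and a float type."""

    def __init__(self, device="cpu", float_type=torch.float32):
        self.device = torch.device(device)
        self.float = float_type
        # Bind the per-object settings once; later calls only pass the data.
        self.tfloat = partial(tfloat, device=self.device, float_type=self.float)

    def to_float(self, states):
        # Equivalent to tfloat(states, device=self.device, float_type=self.float).
        return self.tfloat(states)
```

Inside such a class, the call becomes `states = self.tfloat(states)`. Compared with a module-level default (a global device and float type), binding per object keeps objects with different settings independent.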
@engmubarak48
Collaborator

> which end up making a pretty long line.

But what is wrong with a long line? If line length is the only issue, I think a code formatter can take care of it, unless there is some other problem with it.

Another option is to move the tensors to the device and convert them to float early on, and then simply call states = tfloat(states).
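This option could be paired with optional arguments. Below is a minimal sketch of a hypothetical `tfloat` variant whose device and float_type default to None. An existing tensor keeps its current device, and keeps its dtype if it is already floating point, so after an early conversion the short call returns the tensor unchanged. Non-tensor inputs fall back to PyTorch's default float dtype. Lists of tensors are not handled in this sketch.

```python
import torch


def tfloat(x, device=None, float_type=None):
    """Hypothetical variant: device and float_type are optional."""
    if not torch.is_tensor(x):
        # Non-tensor input: use PyTorch's default float dtype if none is given.
        dtype = float_type if float_type is not None else torch.get_default_dtype()
        return torch.tensor(x, dtype=dtype, device=device)
    if float_type is None:
        # Keep an existing float dtype; otherwise cast to the default float dtype.
        float_type = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    if device is None:
        # Keep the tensor where it already is.
        device = x.device
    # Tensor.to returns x itself when dtype and device already match.
    return x.to(device=device, dtype=float_type)
```

Convert once where the batch is created, e.g. `states = tfloat(states, device=self.device, float_type=self.float)`, and write `states = tfloat(states)` downstream. One catch: a forgotten early conversion no longer fails loudly, because a CPU tensor simply stays on the CPU.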
