Easier environment definition #143
Merged: 30 commits, Nov 25, 2023
Conversation

josephdviviano (Collaborator):

  • Added helper functions to DiscreteEnv to make mask definition easier.
  • Added a helper function for mask casting to the base class.
  • Changed the default behaviour of the log_reward and reward methods (see the sketch below).
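
A rough, library-independent sketch of the mask-helper idea: set_nonexit_masks is one of the helpers reviewed later in this thread, but the signature, the shapes, and the last-index-is-exit convention shown here are assumptions, not the actual torchgfn code.

```python
import torch

def set_nonexit_masks(forward_masks: torch.Tensor, cond: torch.Tensor,
                      allow_exit: bool = False) -> torch.Tensor:
    """Set the non-exit action entries of forward_masks from cond.

    forward_masks: bool tensor of shape (*batch_shape, n_actions)
    cond:          bool tensor of shape (*batch_shape, n_actions - 1)
    """
    masks = forward_masks.clone()
    masks[..., :-1] = cond        # one entry per non-exit action
    masks[..., -1] = allow_exit   # last index assumed to be the exit action
    return masks

# Hypergrid-style usage: a move along a dimension stays allowed unless the state
# already sits at the boundary (mirrors `self.tensor != env.height - 1` quoted below).
batch, ndim, height = 4, 2, 8
states = torch.randint(0, height, (batch, ndim))
masks = torch.ones(batch, ndim + 1, dtype=torch.bool)   # ndim moves + 1 exit action
masks = set_nonexit_masks(masks, states != height - 1, allow_exit=True)
```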

@josephdviviano josephdviviano self-assigned this Oct 20, 2023
@josephdviviano josephdviviano changed the base branch from master to stable October 20, 2023 17:42

@saleml (Collaborator) left a comment:

Great PR! Thanks. Could you also run `pre-commit run --all` at the end? Some files will get modified.

@@ -371,3 +384,55 @@ def _extend(masks, first_dim):

self.forward_masks = _extend(self.forward_masks, required_first_dim)
self.backward_masks = _extend(self.backward_masks, required_first_dim)

# The helper methods are convienience functions for common mask operations.

Collaborator:

typo

def set_nonexit_masks(self, cond, allow_exit: bool = False):
"""Sets the allowable actions according to cond, appending the exit mask.

A convienience function for common mask operations.

Collaborator:

ditto

Collaborator (author):

I just apparently can't spell this one :)

A convienience function for common mask operations.

Args:
cond: a boolean of shape (batch_shape,) + (state_shape,), which

Collaborator:

Isn't this true only when state_shape = n_actions - 1?

I think you meant n_actions - 1 rather than state_shape.

Collaborator (author):
you're right -- good catch

src/gfn/env.py Outdated
@@ -184,12 +190,12 @@ def backward_step(
return new_states

def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
"""Either this or log_reward needs to be implemented."""
return torch.exp(self.log_reward(final_states))
"""This (and potentially log_reward) needs to be implemented."""

Collaborator:

Why "and"?

Collaborator (author):

fixed the docs
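
As an aside, a minimal sketch of the "implement either one" default being discussed: the reward() fallback (exp of log_reward) is exactly what the diff shows, while the symmetric log_reward() fallback is an assumption about the intended pattern.

```python
import torch

class EnvSketch:
    # Subclasses should override at least one of reward / log_reward;
    # if neither is overridden, these defaults would recurse forever.
    def reward(self, final_states) -> torch.Tensor:
        return torch.exp(self.log_reward(final_states))

    def log_reward(self, final_states) -> torch.Tensor:
        return torch.log(self.reward(final_states))

class MyEnv(EnvSketch):
    def log_reward(self, final_states) -> torch.Tensor:
        # Implementing only log_reward is enough: reward() falls back to exp(log_reward).
        return -(final_states ** 2).sum(-1)

env = MyEnv()
print(env.reward(torch.tensor([[0.0, 1.0], [2.0, 0.0]])))  # exp(-1), exp(-4)
```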

self.backward_masks,
)

self.set_default_typing()
self.forward_masks[..., :-1] = self.tensor != env.height - 1

Collaborator:

Do you think we should use set_nonexit_masks for this line?

Collaborator (author):

makes sense, good catch!

@saleml (Collaborator) commented Oct 23, 2023:

Also, it seems like there are GitHub Actions now. Do you happen to know why the checks fail?

@josephdviviano (Collaborator, author) commented Oct 23, 2023 via email.

@saleml saleml mentioned this pull request Oct 25, 2023

@josephdviviano (Collaborator, author):

All comments fixed -- just waiting to see if the checks pass :)

@saleml (Collaborator) left a comment:

LGTM

@@ -23,16 +23,21 @@ def __init__(
sf: Optional[TT["state_shape", torch.float]] = None,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
log_reward_clip: Optional[float] = -100.0,

Collaborator:

Good idea!
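
A hypothetical illustration of how a log_reward_clip value such as -100.0 is typically applied (clamping log-rewards from below so that zero-reward terminal states do not yield -inf); whether torchgfn uses it in exactly this way is an assumption.

```python
import torch

log_reward_clip = -100.0
raw = torch.log(torch.tensor([0.0, 1e-46, 2.0], dtype=torch.float64))
# raw: [-inf, ~-105.9, ~0.693]
clipped = raw.clamp_min(log_reward_clip)
# clipped: [-100.0, -100.0, ~0.693]
```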

self.is_discrete = True
self.log_reward_clip = log_reward_clip

Collaborator:
unnecessary

@@ -303,6 +303,19 @@ def __init__(
self.forward_masks = cast(torch.Tensor, forward_masks)
self.backward_masks = cast(torch.Tensor, backward_masks)

self.set_default_typing()

def set_default_typing(self) -> None:

Collaborator:
great idea!
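
For context (general Python behaviour, not torchgfn-specific): typing.cast is a no-op at runtime, so a set_default_typing-style helper only tells static type checkers to treat the masks as plain torch.Tensor. The body below is a guess at the idea, not the actual implementation.

```python
from typing import cast

import torch

class DiscreteStatesSketch:
    def __init__(self, forward_masks: torch.Tensor, backward_masks: torch.Tensor):
        self.forward_masks = forward_masks
        self.backward_masks = backward_masks
        self.set_default_typing()

    def set_default_typing(self) -> None:
        # Purely for the type checker; no runtime effect.
        self.forward_masks = cast(torch.Tensor, self.forward_masks)
        self.backward_masks = cast(torch.Tensor, self.backward_masks)
```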

super().__init__(**kwargs)
self.logZ_value = nn.Parameter(logZ_value)

Collaborator:

The only place BoxStateFlowModule is used is in train_box.py:

logZ = torch.tensor(0.0, device=env.device, requires_grad=True)
# We need a LogStateFlowEstimator

module = BoxStateFlowModule(
    input_dim=env.preprocessor.output_dim,
    output_dim=1,
    hidden_dim=args.hidden_dim,
    n_hidden_layers=args.n_hidden,
    torso=None,  # We do not tie the parameters of the flow function to PF
    logZ_value=logZ,
)

Naive PyTorch question: why do we need nn.Parameter?
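
(General PyTorch behaviour relevant to the question above, not a statement about what torchgfn requires: nn.Parameter registers the tensor on the module, so it appears in module.parameters() and state_dict() and is therefore seen by optimizers and checkpointing; a plain tensor attribute, even with requires_grad=True, is not registered.)

```python
import torch
from torch import nn

class FlowModuleSketch(nn.Module):
    def __init__(self, logZ_value: torch.Tensor):
        super().__init__()
        self.logZ_param = nn.Parameter(logZ_value)                # registered
        self.logZ_plain = torch.tensor(0.0, requires_grad=True)   # not registered

m = FlowModuleSketch(torch.tensor(0.0))
print([name for name, _ in m.named_parameters()])  # ['logZ_param'] only
```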

def true_reward(
    self, final_states: DiscreteStates
) -> TT["batch_shape", torch.float]:
def reward(self, final_states: DiscreteStates) -> TT["batch_shape", torch.float]:

Collaborator:

Right! true_reward was useless.

@@ -371,3 +384,55 @@ def _extend(masks, first_dim):

self.forward_masks = _extend(self.forward_masks, required_first_dim)
self.backward_masks = _extend(self.backward_masks, required_first_dim)

# The helper methods are convenience functions for common mask operations.
def set_nonexit_masks(self, cond, allow_exit: bool = False):

Collaborator:

Are there other places than hypergrid.py where this is used?

dim=-1,
).bool()

def init_forward_masks(self, set_ones: bool = True):

Collaborator:
perfect
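
A guess at what an init_forward_masks-style helper does (its body is not shown in this hunk): (re)initialise the forward masks to all-True or all-False booleans of shape (*batch_shape, n_actions).

```python
import torch

def init_forward_masks(batch_shape, n_actions: int, set_ones: bool = True) -> torch.Tensor:
    fill = torch.ones if set_ones else torch.zeros
    return fill(*batch_shape, n_actions, dtype=torch.bool)

masks = init_forward_masks((4,), 3)   # shape (4, 3), all True
```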

@josephdviviano (Collaborator, author):

Tests pass!

@saleml (Collaborator) left a comment:

This is a very important PR. Thank you @josephdviviano

@saleml (Collaborator) commented Nov 24, 2023:

Should this be merged into master or into stable?
I thought the logic was: put everything in master, and once we're happy with master, push master to stable.

@saleml saleml changed the base branch from stable to master November 25, 2023 16:58
@saleml saleml merged commit 1603723 into master Nov 25, 2023
3 checks passed
@josephdviviano josephdviviano deleted the easier_environment_definition branch February 16, 2024 19:05