Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BIG PR: Simplification and homegenisation of environments, addition of tests, other clean up changes #101

Merged
merged 118 commits into from
Mar 31, 2023
Merged
Show file tree
Hide file tree
Changes from 111 commits
Commits
Show all changes
118 commits
Select commit Hold shift + click to select a range
ebe36ad
changes to make uniform proxy work; weirdly there is still a seemingl…
alexhernandezgarcia Feb 24, 2023
a98d80f
Merge branch 'continuous' into uniform-proxy-fix
alexhernandezgarcia Feb 28, 2023
3fe2d83
wip: init grid tests
alexhernandezgarcia Feb 28, 2023
2eeec19
Merge branch 'gdown' into tests-and-simplify-envs
alexhernandezgarcia Feb 28, 2023
4af828c
delete old code
alexhernandezgarcia Feb 28, 2023
df87514
add missing import
alexhernandezgarcia Feb 28, 2023
504bb28
new way of setting up proxy
alexhernandezgarcia Feb 28, 2023
ea4f5b7
Merge branch 'continuous' into tests-and-simplify-envs
alexhernandezgarcia Feb 28, 2023
10a2562
remove old torus rounds script
alexhernandezgarcia Feb 28, 2023
c181d0b
remove unnecessary repeated stuff in alaninedipeptide env
alexhernandezgarcia Feb 28, 2023
365f425
remove unnecessary repeated stuff in aptamers env
alexhernandezgarcia Feb 28, 2023
c40649a
move to env base the common operations and simplify grid for now
alexhernandezgarcia Feb 28, 2023
f8a19f6
logger base config includes all args to work by itself
alexhernandezgarcia Mar 1, 2023
bbe2b48
revert copy to deepcopy
alexhernandezgarcia Mar 1, 2023
6f259ab
buffer in a separate file in utils
alexhernandezgarcia Mar 1, 2023
1cabfdf
match default config with default args in env base£
alexhernandezgarcia Mar 1, 2023
cf7216c
basic gflownet test
alexhernandezgarcia Mar 1, 2023
573fc03
isort
alexhernandezgarcia Mar 1, 2023
5ea3522
wip: test to check default configs of envs
alexhernandezgarcia Mar 1, 2023
31f9aa1
use pytest repeat
alexhernandezgarcia Mar 1, 2023
b3a5586
remove true_density
alexhernandezgarcia Mar 1, 2023
f7a26d3
remove np2df
alexhernandezgarcia Mar 1, 2023
bf9f6f3
black
alexhernandezgarcia Mar 1, 2023
f706131
remove make_train and make_test methods
alexhernandezgarcia Mar 1, 2023
2b781c5
get_action_space <- get_actions_space
alexhernandezgarcia Mar 1, 2023
78d11e1
in env base.py add docstring, change order of methods, add typing, etc.
alexhernandezgarcia Mar 1, 2023
60f117d
remove no_eos_mask from aptamers
alexhernandezgarcia Mar 1, 2023
e541184
add types
alexhernandezgarcia Mar 2, 2023
6b4b30b
format
alexhernandezgarcia Mar 2, 2023
f13ab4b
types and done optional in set state
alexhernandezgarcia Mar 2, 2023
4766543
add assertions in test test__get_parents_step_get_mask__are_compatible
alexhernandezgarcia Mar 2, 2023
176da09
add get_max_traj_len to grid
alexhernandezgarcia Mar 3, 2023
4fe182e
add backward sampling test
alexhernandezgarcia Mar 3, 2023
d28169a
add repetitions to test backward sampling
alexhernandezgarcia Mar 3, 2023
b2ab4e9
types and fix state2oracle
alexhernandezgarcia Mar 3, 2023
a90b4b6
grid add tests state 2 oracle
alexhernandezgarcia Mar 3, 2023
429f07b
grid remove reset
alexhernandezgarcia Mar 3, 2023
f3b05b2
add test state conversions are reversible
alexhernandezgarcia Mar 3, 2023
80046bc
remove todo comment; specify dim in typing
alexhernandezgarcia Mar 3, 2023
f34f5fb
tests of sample_actions and logprobs
alexhernandezgarcia Mar 3, 2023
93fcd20
black
alexhernandezgarcia Mar 3, 2023
086100d
get_max_traj_length <- get_max_traj_len and make self.max_traj_length…
alexhernandezgarcia Mar 3, 2023
77802c6
tests get_parents at source and end states
alexhernandezgarcia Mar 3, 2023
a5860f9
correct name
alexhernandezgarcia Mar 3, 2023
ff523b8
test step returns invalid if done
alexhernandezgarcia Mar 3, 2023
1b50ef5
create common function that runs all common tests
alexhernandezgarcia Mar 3, 2023
c87db62
add additional checks to grid step
alexhernandezgarcia Mar 3, 2023
58cdfd7
Merge branch 'valids-batch' into tests-and-simplify-envs
alexhernandezgarcia Mar 3, 2023
e045704
add todo
alexhernandezgarcia Mar 3, 2023
a4a1c31
format docstring
alexhernandezgarcia Mar 3, 2023
16d51c2
add comments and change order of lines in init
alexhernandezgarcia Mar 3, 2023
1890654
docs
alexhernandezgarcia Mar 3, 2023
ef172a8
spacegroup: update init, typing, docs
alexhernandezgarcia Mar 3, 2023
5dff51d
add common tests to spacegroup
alexhernandezgarcia Mar 3, 2023
ce139a2
fix bug with n_actions in step func
alexhernandezgarcia Mar 3, 2023
3d940b9
move setting of conditioned properties into new function
alexhernandezgarcia Mar 3, 2023
6b24cb8
add statetorch statebatch 2 oracle
alexhernandezgarcia Mar 3, 2023
7392400
isort and black envs
alexhernandezgarcia Mar 3, 2023
60e1681
self.d_action_space for the len' WIP of actions2indices
alexhernandezgarcia Mar 3, 2023
4a18a3c
fix actions2indices
alexhernandezgarcia Mar 4, 2023
d84afdc
test actions2indices
alexhernandezgarcia Mar 4, 2023
4194ad5
docstring
alexhernandezgarcia Mar 4, 2023
1ffc1d5
wip: start work to change action space of hypergrid
alexhernandezgarcia Mar 8, 2023
6110f4f
completed change of action format in grid
alexhernandezgarcia Mar 9, 2023
f72ee9f
add grid common test with env with extended action space
alexhernandezgarcia Mar 9, 2023
87ef339
Update gflownet/envs/grid.py
alexhernandezgarcia Mar 10, 2023
93447f0
fix get_action_space and add explicit test for it
alexhernandezgarcia Mar 11, 2023
2c31d02
put assertions first all together
alexhernandezgarcia Mar 11, 2023
5cab20e
if max_dim_per_action is -1 then it is n_dim
alexhernandezgarcia Mar 11, 2023
a6cb395
Merge branch 'actions-grid' of github.com:alexhernandezgarcia/gflowne…
alexhernandezgarcia Mar 11, 2023
0fb3bbe
Update config/env/grid.yaml
alexhernandezgarcia Mar 11, 2023
64fb531
Merge pull request #99 from alexhernandezgarcia/actions-grid
alexhernandezgarcia Mar 14, 2023
23f50bf
mypy (partial) in torus
alexhernandezgarcia Mar 22, 2023
307235f
Merge branch 'tests-and-simplify-envs' of github.com:alexhernandezgar…
alexhernandezgarcia Mar 22, 2023
c8b1d5d
remove Constant comment
alexhernandezgarcia Mar 22, 2023
aa979dd
remove prints
alexhernandezgarcia Mar 22, 2023
00ce208
mypy and docstring
alexhernandezgarcia Mar 22, 2023
e8a0bf0
remove reset
alexhernandezgarcia Mar 22, 2023
e350761
mypy to get_parents
alexhernandezgarcia Mar 22, 2023
fcab574
update get_parents
alexhernandezgarcia Mar 22, 2023
9debc8a
update get_parents, step get_all_terminating_states
alexhernandezgarcia Mar 22, 2023
f0cfe03
tests with more complex envs
alexhernandezgarcia Mar 22, 2023
74b4887
eos is an action
alexhernandezgarcia Mar 23, 2023
128ab75
update eos format in test
alexhernandezgarcia Mar 24, 2023
17a2ccd
add fixed_distribution and random_distribution as arguments of the ba…
alexhernandezgarcia Mar 24, 2023
86a03b7
black
alexhernandezgarcia Mar 24, 2023
d0890aa
d -> dim
alexhernandezgarcia Mar 24, 2023
a025304
simplify and update htorus
alexhernandezgarcia Mar 24, 2023
f646a1e
remove reset and move copy down
alexhernandezgarcia Mar 24, 2023
c3a0ba5
make incr in actions space 0 instead of None; include explicitly fixe…
alexhernandezgarcia Mar 24, 2023
b6439b4
get_parents return no parents no actions for source state and unpack …
alexhernandezgarcia Mar 24, 2023
ad8fb13
mypy
alexhernandezgarcia Mar 24, 2023
a49392e
change action format and eos - now actions have just length n_dim and…
alexhernandezgarcia Mar 25, 2023
6bbeb6e
fix issue with dict params and define eos in get_action_space to over…
alexhernandezgarcia Mar 25, 2023
8e54472
mask_stop_actions -> mask_invalid_actions
alexhernandezgarcia Mar 25, 2023
45cf382
add preprocess to readable2state but not working always
alexhernandezgarcia Mar 25, 2023
6d4320c
separate group of tests for continuous and sample actions using sampl…
alexhernandezgarcia Mar 25, 2023
448df92
basic tests for ctorus and htorus (need to add more)
alexhernandezgarcia Mar 25, 2023
996cf58
add setup to torus proxy
alexhernandezgarcia Mar 25, 2023
442c557
add missing call to setup_proxy() in envs base
alexhernandezgarcia Mar 25, 2023
399f65b
black
alexhernandezgarcia Mar 25, 2023
3a050a3
remove copy because it is in parent class
alexhernandezgarcia Mar 25, 2023
96b9e89
wrap long lines
alexhernandezgarcia Mar 25, 2023
d645458
make eos action and simplify like other envs
alexhernandezgarcia Mar 25, 2023
4f812dc
remove comment line
alexhernandezgarcia Mar 25, 2023
564b8cd
black
alexhernandezgarcia Mar 25, 2023
89f9e1c
isort
alexhernandezgarcia Mar 25, 2023
9d40888
self.d_action_space -> self.action_space_dim
alexhernandezgarcia Mar 26, 2023
5d98f42
d -> dim in grid
alexhernandezgarcia Mar 26, 2023
35d7762
black, isort
alexhernandezgarcia Mar 26, 2023
e623571
fix old handling of eos action in backward sampling
alexhernandezgarcia Mar 28, 2023
b576ade
black; isort --profile black
alexhernandezgarcia Mar 31, 2023
d17c632
Update gflownet/proxy/base.py
alexhernandezgarcia Mar 31, 2023
cd63e6b
Update gflownet/proxy/base.py
alexhernandezgarcia Mar 31, 2023
3d71a0a
fix issues introduced during review
alexhernandezgarcia Mar 31, 2023
f364ba2
proxy method setup is not abstract because it is not always necessary
alexhernandezgarcia Mar 31, 2023
249d7d7
Update gflownet/envs/base.py
alexhernandezgarcia Mar 31, 2023
f31990c
Update gflownet/envs/base.py
alexhernandezgarcia Mar 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 5 additions & 2 deletions config/env/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@ _target_: gflownet.envs.base.GFlowNetEnv
# Reward function: power or boltzmann
# boltzmann: exp(-1.0 * reward_beta * proxy)
# power: (-1.0 * proxy / reward_norm) ** self.reward_beta
reward_func: boltzmann
# identity: proxy
reward_func: identity
# Minimum reward
reward_min: 1e-8
# Beta parameter of the reward function
reward_beta: 1.0
# Reward normalization for "power" reward function
reward_norm: 1.0
# If > 0, reward_norm = reward_norm_std_mult * std(energies)
reward_norm_std_mult: 0.0
proxy_state_format: oracle
reward_norm_std_mult: 8
# Buffer
buffer:
replay_capacity: 10
Expand Down
7 changes: 4 additions & 3 deletions config/env/grid.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ func: corners
n_dim: 2
# Number of cells per dimension
length: 3
# Minimum and maximum number of steps in the action space
min_step_len: 1
max_step_len: 1
# Maximum increment per each dimension that can be done by one action
max_increment: 1
# Maximum number of dimensions that can be incremented by one action
max_dim_per_action: 1
# Mapping coordinates
cell_min: -1
cell_max: 1
Expand Down
7 changes: 4 additions & 3 deletions config/env/torus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ n_dim: 2
n_angles: 8
# Maximum number of rounds
length_traj: 12
# Minimum and maximum number of steps in the action space
min_step_len: 1
max_step_len: 1
# Maximum increment per each dimension that can be done by one action
max_increment: 1
# Maximum number of dimensions that can be incremented by one action
max_dim_per_action: 1
# Buffer
buffer:
data_path: null
Expand Down
21 changes: 0 additions & 21 deletions config/env/torus_rounds.yaml

This file was deleted.

4 changes: 3 additions & 1 deletion config/logger/base.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
_target_: logger.Logger
_target_: gflownet.utils.logger.Logger

do:
online: False
times: False

project_name: "GFlowNet"

# Train metrics
train:
period: 1
Expand Down
2 changes: 0 additions & 2 deletions config/logger/wandb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,5 @@ _target_: gflownet.utils.logger.Logger
do:
online: True

project_name: "GFlowNet"

tags:
- gflownet
27 changes: 27 additions & 0 deletions config/tests.yaml
alexhernandezgarcia marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
defaults:
- _self_
- env: grid
- gflownet: flowmatch
- proxy: uniform
- logger: base
- user: alex

# Device
device: cpu
# Float precision
float_precision: 32
# Number of objects to sample at the end of training
n_samples: 1
# Random seeds
seed: 0

# Hydra config
hydra:
# See: https://hydra.cc/docs/configure_hydra/workdir/
run:
dir: ${user.logdir.root}/${now:%Y-%m-%d_%H-%M-%S}_tests
job:
# See: https://hydra.cc/docs/upgrades/1.1_to_1.2/changes_to_job_working_dir/
# See: https://hydra.cc/docs/tutorials/basic/running_your_app/working_directory/#disable-changing-current-working-dir-to-jobs-output-dir
chdir: True

59 changes: 10 additions & 49 deletions gflownet/envs/alaninedipeptide.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from copy import deepcopy
from typing import List, Tuple

import numpy as np
import numpy.typing as npt
import torch

from copy import deepcopy
from typing import List, Tuple
from torchtyping import TensorType

from gflownet.envs.ctorus import ContinuousTorus
Expand All @@ -20,48 +20,17 @@ def __init__(
self,
path_to_dataset,
url_to_dataset,
length_traj=1,
fixed_distribution=dict,
random_distribution=dict,
vonmises_min_concentration=1e-3,
env_id=None,
reward_beta=1,
reward_norm=1.0,
reward_norm_std_mult=0,
reward_func="boltzmann",
denorm_proxy=False,
energies_stats=None,
proxy=None,
oracle=None,
policy_encoding_dim_per_angle=None,
n_comp=3,
**kwargs,
):
self.atom_positions_dataset = AtomPositionsDataset(path_to_dataset, url_to_dataset)
self.atom_positions_dataset = AtomPositionsDataset(
path_to_dataset, url_to_dataset
)
atom_positions = self.atom_positions_dataset.sample()
self.conformer = ConformerBase(
atom_positions, constants.ad_smiles, constants.ad_free_tas
)
n_dim = len(self.conformer.freely_rotatable_tas)
super(AlanineDipeptide, self).__init__(
n_dim=n_dim,
length_traj=length_traj,
fixed_distribution=fixed_distribution,
random_distribution=random_distribution,
vonmises_min_concentration=vonmises_min_concentration,
env_id=env_id,
reward_beta=reward_beta,
reward_norm=reward_norm,
reward_norm_std_mult=reward_norm_std_mult,
reward_func=reward_func,
denorm_proxy=denorm_proxy,
energies_stats=energies_stats,
proxy=proxy,
oracle=oracle,
policy_encoding_dim_per_angle=policy_encoding_dim_per_angle,
n_comp=n_comp,
**kwargs,
)
super().__init__(**kwargs)
self.sync_conformer_with_state()

def sync_conformer_with_state(self, state: List = None):
Expand All @@ -71,13 +40,7 @@ def sync_conformer_with_state(self, state: List = None):
self.conformer.set_torsion_angle(ta, state[idx])
return self.conformer

def copy(self):
# return an instance of the environment
return deepcopy(self)

def statetorch2proxy(
self, states: TensorType["batch", "state_dim"]
) -> npt.NDArray:
def statetorch2proxy(self, states: TensorType["batch", "state_dim"]) -> npt.NDArray:
"""
Prepares a batch of states in torch "GFlowNet format" for the oracle.
"""
Expand All @@ -88,16 +51,14 @@ def statetorch2proxy(
np_states = states.cpu().numpy()
return np_states[:, :-1]

def statebatch2proxy(
self, states: List[List]
) -> npt.NDArray:
def statebatch2proxy(self, states: List[List]) -> npt.NDArray:
"""
Prepares a batch of states in "GFlowNet format" for the proxy: a tensor where
each state is a row of length n_dim with an angle in radians. The n_actions
item is removed.
"""
return np.array(states)[:, :-1]

def statetorch2oracle(
self, states: TensorType["batch", "state_dim"]
) -> List[Tuple[npt.NDArray, npt.NDArray]]:
Expand Down
121 changes: 13 additions & 108 deletions gflownet/envs/aptamers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""
Classes to represent aptamers environments
"""
from typing import List
import itertools
import time
from typing import List

import numpy as np
import numpy.typing as npt
import pandas as pd

from gflownet.envs.base import GFlowNetEnv
import time


class AptamerSeq(GFlowNetEnv):
Expand Down Expand Up @@ -51,47 +53,27 @@ def __init__(
n_alphabet=4,
min_word_len=1,
max_word_len=1,
proxy=None,
oracle=None,
reward_beta=1,
env_id=None,
energies_stats=None,
reward_norm=1.0,
reward_norm_std_mult=0.0,
reward_func="power",
denorm_proxy=False,
**kwargs,
):
super(AptamerSeq, self).__init__(
env_id,
reward_beta,
reward_norm,
reward_norm_std_mult,
reward_func,
energies_stats,
denorm_proxy,
proxy,
oracle,
**kwargs,
)
super().__init__()
self.source = []
self.min_seq_length = min_seq_length
self.max_seq_length = max_seq_length
self.n_alphabet = n_alphabet
self.min_word_len = min_word_len
self.max_word_len = max_word_len
self.action_space = self.get_actions_space()
self.eos = len(self.action_space)
self.action_space = self.get_action_space()
self.eos = self.action_space_dim
self.reset()
self.fixed_policy_output = self.get_fixed_policy_output()
self.random_policy_output = self.get_fixed_policy_output()
self.policy_output_dim = len(self.fixed_policy_output)
self.policy_input_dim = len(self.state2policy())
self.max_traj_len = self.get_max_traj_len()
self.max_traj_len = self.get_max_traj_length()
# Set up proxy
self.proxy.setup(self.max_seq_length)
self.setup_proxy()

def get_actions_space(self):
def get_action_space(self):
"""
Constructs list with all possible actions
"""
Expand All @@ -104,7 +86,7 @@ def get_actions_space(self):
actions += actions_r
return actions

def get_max_traj_len(
def get_max_traj_length(
self,
):
return self.max_seq_length / self.min_word_len + 1
Expand Down Expand Up @@ -324,8 +306,8 @@ def get_mask_invalid_actions_forward(self, state=None, done=None):
if done is None:
done = self.done
if done:
return [True for _ in range(len(self.action_space) + 1)]
mask = [False for _ in range(len(self.action_space) + 1)]
return [True for _ in range(self.action_space_dim + 1)]
mask = [False for _ in range(self.action_space_dim + 1)]
seq_length = len(state)
if seq_length < self.min_seq_length:
mask[self.eos] = True
Expand All @@ -334,50 +316,6 @@ def get_mask_invalid_actions_forward(self, state=None, done=None):
mask[idx] = True
return mask

def no_eos_mask(self, state=None):
"""
Returns True if no eos action is allowed given state
"""
if state is None:
state = self.state.copy()
return len(state) < self.min_seq_length

def true_density(self, max_states=1e6):
"""
Computes the reward density (reward / sum(rewards)) of the whole space, if the
dimensionality is smaller than specified in the arguments.

Returns
-------
Tuple:
- normalized reward for each state
- states
- (un-normalized) reward)
"""
if self._true_density is not None:
return self._true_density
if self.n_alphabet**self.max_seq_length > max_states:
return (None, None, None)
state_all = np.int32(
list(
itertools.product(*[list(range(self.n_alphabet))] * self.max_seq_length)
)
)
traj_rewards, state_end = zip(
*[
(self.proxy(state), state)
for state in state_all
if len(self.get_parents(state, False)[0]) > 0 or sum(state) == 0
]
)
traj_rewards = np.array(traj_rewards)
self._true_density = (
traj_rewards / traj_rewards.sum(),
list(map(tuple, state_end)),
traj_rewards,
)
return self._true_density

def make_train_set(
self,
ntrain,
Expand Down Expand Up @@ -491,36 +429,3 @@ def make_test_set(
t1_all = time.time()
times["all"] += t1_all - t0_all
return df_test, times

@staticmethod
def np2df(test_path, al_init_length, al_queries_per_iter, pct_test, data_seed):
data_dict = np.load(test_path, allow_pickle=True).item()
letters = numbers2letters(data_dict["samples"])
df = pd.DataFrame(
{
"samples": letters,
"energies": data_dict["energies"],
"train": [False] * len(letters),
"test": [False] * len(letters),
}
)
# Split train and test section of init data set
rng = np.random.default_rng(data_seed)
indices = rng.permutation(al_init_length)
n_tt = int(pct_test * len(indices))
indices_tt = indices[:n_tt]
indices_tr = indices[n_tt:]
df.loc[indices_tt, "test"] = True
df.loc[indices_tr, "train"] = True
# Split train and test the section of each iteration to preserve splits
idx = al_init_length
iters_remaining = (len(df) - al_init_length) // al_queries_per_iter
indices = rng.permutation(al_queries_per_iter)
n_tt = int(pct_test * len(indices))
for it in range(iters_remaining):
indices_tt = indices[:n_tt] + idx
indices_tr = indices[n_tt:] + idx
df.loc[indices_tt, "test"] = True
df.loc[indices_tr, "train"] = True
idx += al_queries_per_iter
return df