diff --git a/config/env/base.yaml b/config/env/base.yaml index dd60b080f..785acee6b 100644 --- a/config/env/base.yaml +++ b/config/env/base.yaml @@ -3,14 +3,17 @@ _target_: gflownet.envs.base.GFlowNetEnv # Reward function: power or boltzmann # boltzmann: exp(-1.0 * reward_beta * proxy) # power: (-1.0 * proxy / reward_norm) ** self.reward_beta -reward_func: boltzmann +# identity: proxy +reward_func: identity +# Minimum reward +reward_min: 1e-8 # Beta parameter of the reward function reward_beta: 1.0 # Reward normalization for "power" reward function reward_norm: 1.0 # If > 0, reward_norm = reward_norm_std_mult * std(energies) +reward_norm_std_mult: 0.0 proxy_state_format: oracle -reward_norm_std_mult: 8 # Buffer buffer: replay_capacity: 10 diff --git a/config/env/grid.yaml b/config/env/grid.yaml index 304ca4943..24251a2b4 100644 --- a/config/env/grid.yaml +++ b/config/env/grid.yaml @@ -9,9 +9,10 @@ func: corners n_dim: 2 # Number of cells per dimension length: 3 -# Minimum and maximum number of steps in the action space -min_step_len: 1 -max_step_len: 1 +# Maximum increment per each dimension that can be done by one action +max_increment: 1 +# Maximum number of dimensions that can be incremented by one action +max_dim_per_action: 1 # Mapping coordinates cell_min: -1 cell_max: 1 diff --git a/config/env/torus.yaml b/config/env/torus.yaml index 1d6917568..a5959b79a 100644 --- a/config/env/torus.yaml +++ b/config/env/torus.yaml @@ -11,9 +11,10 @@ n_dim: 2 n_angles: 8 # Maximum number of rounds length_traj: 12 -# Minimum and maximum number of steps in the action space -min_step_len: 1 -max_step_len: 1 +# Maximum increment per each dimension that can be done by one action +max_increment: 1 +# Maximum number of dimensions that can be incremented by one action +max_dim_per_action: 1 # Buffer buffer: data_path: null diff --git a/config/env/torus_rounds.yaml b/config/env/torus_rounds.yaml deleted file mode 100644 index 1ca00a05c..000000000 --- a/config/env/torus_rounds.yaml +++ /dev/null @@ -1,21 +0,0 @@ -defaults: - - base - -_target_: gflownet.envs.torus.Torus - -id: torus -func: sincos -# Dimensions of hypertorus -n_dim: 2 -# Number of angles per dimension -n_angles: 4 -# Maximum number of rounds -max_rounds: 1 -# Minimum and maximum number of steps in the action space -min_step_len: 1 -max_step_len: 1 -# Buffer -buffer: - data_path: null - train: null - test: null diff --git a/config/logger/base.yaml b/config/logger/base.yaml index b3b587293..be668b9f1 100644 --- a/config/logger/base.yaml +++ b/config/logger/base.yaml @@ -1,9 +1,11 @@ -_target_: logger.Logger +_target_: gflownet.utils.logger.Logger do: online: False times: False +project_name: "GFlowNet" + # Train metrics train: period: 1 diff --git a/config/logger/wandb.yaml b/config/logger/wandb.yaml index c91b71440..5ccbd3de8 100644 --- a/config/logger/wandb.yaml +++ b/config/logger/wandb.yaml @@ -6,7 +6,5 @@ _target_: gflownet.utils.logger.Logger do: online: True -project_name: "GFlowNet" - tags: - gflownet diff --git a/config/tests.yaml b/config/tests.yaml new file mode 100644 index 000000000..95309c734 --- /dev/null +++ b/config/tests.yaml @@ -0,0 +1,27 @@ +defaults: + - _self_ + - env: grid + - gflownet: flowmatch + - proxy: uniform + - logger: base + - user: alex + +# Device +device: cpu +# Float precision +float_precision: 32 +# Number of objects to sample at the end of training +n_samples: 1 +# Random seeds +seed: 0 + +# Hydra config +hydra: + # See: https://hydra.cc/docs/configure_hydra/workdir/ + run: + dir: 
${user.logdir.root}/${now:%Y-%m-%d_%H-%M-%S}_tests + job: + # See: https://hydra.cc/docs/upgrades/1.1_to_1.2/changes_to_job_working_dir/ + # See: https://hydra.cc/docs/tutorials/basic/running_your_app/working_directory/#disable-changing-current-working-dir-to-jobs-output-dir + chdir: True + diff --git a/gflownet/envs/alaninedipeptide.py b/gflownet/envs/alaninedipeptide.py index 83b8f802e..76b725e3b 100644 --- a/gflownet/envs/alaninedipeptide.py +++ b/gflownet/envs/alaninedipeptide.py @@ -1,9 +1,9 @@ +from copy import deepcopy +from typing import List, Tuple + import numpy as np import numpy.typing as npt import torch - -from copy import deepcopy -from typing import List, Tuple from torchtyping import TensorType from gflownet.envs.ctorus import ContinuousTorus @@ -20,48 +20,17 @@ def __init__( self, path_to_dataset, url_to_dataset, - length_traj=1, - fixed_distribution=dict, - random_distribution=dict, - vonmises_min_concentration=1e-3, - env_id=None, - reward_beta=1, - reward_norm=1.0, - reward_norm_std_mult=0, - reward_func="boltzmann", - denorm_proxy=False, - energies_stats=None, - proxy=None, - oracle=None, - policy_encoding_dim_per_angle=None, - n_comp=3, **kwargs, ): - self.atom_positions_dataset = AtomPositionsDataset(path_to_dataset, url_to_dataset) + self.atom_positions_dataset = AtomPositionsDataset( + path_to_dataset, url_to_dataset + ) atom_positions = self.atom_positions_dataset.sample() self.conformer = ConformerBase( atom_positions, constants.ad_smiles, constants.ad_free_tas ) n_dim = len(self.conformer.freely_rotatable_tas) - super(AlanineDipeptide, self).__init__( - n_dim=n_dim, - length_traj=length_traj, - fixed_distribution=fixed_distribution, - random_distribution=random_distribution, - vonmises_min_concentration=vonmises_min_concentration, - env_id=env_id, - reward_beta=reward_beta, - reward_norm=reward_norm, - reward_norm_std_mult=reward_norm_std_mult, - reward_func=reward_func, - denorm_proxy=denorm_proxy, - energies_stats=energies_stats, - proxy=proxy, - oracle=oracle, - policy_encoding_dim_per_angle=policy_encoding_dim_per_angle, - n_comp=n_comp, - **kwargs, - ) + super().__init__(**kwargs) self.sync_conformer_with_state() def sync_conformer_with_state(self, state: List = None): @@ -71,13 +40,7 @@ def sync_conformer_with_state(self, state: List = None): self.conformer.set_torsion_angle(ta, state[idx]) return self.conformer - def copy(self): - # return an instance of the environment - return deepcopy(self) - - def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] - ) -> npt.NDArray: + def statetorch2proxy(self, states: TensorType["batch", "state_dim"]) -> npt.NDArray: """ Prepares a batch of states in torch "GFlowNet format" for the oracle. """ @@ -88,16 +51,14 @@ def statetorch2proxy( np_states = states.cpu().numpy() return np_states[:, :-1] - def statebatch2proxy( - self, states: List[List] - ) -> npt.NDArray: + def statebatch2proxy(self, states: List[List]) -> npt.NDArray: """ Prepares a batch of states in "GFlowNet format" for the proxy: a tensor where each state is a row of length n_dim with an angle in radians. The n_actions item is removed. 
""" return np.array(states)[:, :-1] - + def statetorch2oracle( self, states: TensorType["batch", "state_dim"] ) -> List[Tuple[npt.NDArray, npt.NDArray]]: diff --git a/gflownet/envs/aptamers.py b/gflownet/envs/aptamers.py index 47a3812a5..425a2eb91 100644 --- a/gflownet/envs/aptamers.py +++ b/gflownet/envs/aptamers.py @@ -1,13 +1,15 @@ """ Classes to represent aptamers environments """ -from typing import List import itertools +import time +from typing import List + import numpy as np import numpy.typing as npt import pandas as pd + from gflownet.envs.base import GFlowNetEnv -import time class AptamerSeq(GFlowNetEnv): @@ -51,47 +53,27 @@ def __init__( n_alphabet=4, min_word_len=1, max_word_len=1, - proxy=None, - oracle=None, - reward_beta=1, - env_id=None, - energies_stats=None, - reward_norm=1.0, - reward_norm_std_mult=0.0, - reward_func="power", - denorm_proxy=False, **kwargs, ): - super(AptamerSeq, self).__init__( - env_id, - reward_beta, - reward_norm, - reward_norm_std_mult, - reward_func, - energies_stats, - denorm_proxy, - proxy, - oracle, - **kwargs, - ) + super().__init__() self.source = [] self.min_seq_length = min_seq_length self.max_seq_length = max_seq_length self.n_alphabet = n_alphabet self.min_word_len = min_word_len self.max_word_len = max_word_len - self.action_space = self.get_actions_space() - self.eos = len(self.action_space) + self.action_space = self.get_action_space() + self.eos = self.action_space_dim self.reset() self.fixed_policy_output = self.get_fixed_policy_output() self.random_policy_output = self.get_fixed_policy_output() self.policy_output_dim = len(self.fixed_policy_output) self.policy_input_dim = len(self.state2policy()) - self.max_traj_len = self.get_max_traj_len() + self.max_traj_len = self.get_max_traj_length() # Set up proxy - self.proxy.setup(self.max_seq_length) + self.setup_proxy() - def get_actions_space(self): + def get_action_space(self): """ Constructs list with all possible actions """ @@ -104,7 +86,7 @@ def get_actions_space(self): actions += actions_r return actions - def get_max_traj_len( + def get_max_traj_length( self, ): return self.max_seq_length / self.min_word_len + 1 @@ -324,8 +306,8 @@ def get_mask_invalid_actions_forward(self, state=None, done=None): if done is None: done = self.done if done: - return [True for _ in range(len(self.action_space) + 1)] - mask = [False for _ in range(len(self.action_space) + 1)] + return [True for _ in range(self.action_space_dim + 1)] + mask = [False for _ in range(self.action_space_dim + 1)] seq_length = len(state) if seq_length < self.min_seq_length: mask[self.eos] = True @@ -334,50 +316,6 @@ def get_mask_invalid_actions_forward(self, state=None, done=None): mask[idx] = True return mask - def no_eos_mask(self, state=None): - """ - Returns True if no eos action is allowed given state - """ - if state is None: - state = self.state.copy() - return len(state) < self.min_seq_length - - def true_density(self, max_states=1e6): - """ - Computes the reward density (reward / sum(rewards)) of the whole space, if the - dimensionality is smaller than specified in the arguments. 
- - Returns - ------- - Tuple: - - normalized reward for each state - - states - - (un-normalized) reward) - """ - if self._true_density is not None: - return self._true_density - if self.n_alphabet**self.max_seq_length > max_states: - return (None, None, None) - state_all = np.int32( - list( - itertools.product(*[list(range(self.n_alphabet))] * self.max_seq_length) - ) - ) - traj_rewards, state_end = zip( - *[ - (self.proxy(state), state) - for state in state_all - if len(self.get_parents(state, False)[0]) > 0 or sum(state) == 0 - ] - ) - traj_rewards = np.array(traj_rewards) - self._true_density = ( - traj_rewards / traj_rewards.sum(), - list(map(tuple, state_end)), - traj_rewards, - ) - return self._true_density - def make_train_set( self, ntrain, @@ -491,36 +429,3 @@ def make_test_set( t1_all = time.time() times["all"] += t1_all - t0_all return df_test, times - - @staticmethod - def np2df(test_path, al_init_length, al_queries_per_iter, pct_test, data_seed): - data_dict = np.load(test_path, allow_pickle=True).item() - letters = numbers2letters(data_dict["samples"]) - df = pd.DataFrame( - { - "samples": letters, - "energies": data_dict["energies"], - "train": [False] * len(letters), - "test": [False] * len(letters), - } - ) - # Split train and test section of init data set - rng = np.random.default_rng(data_seed) - indices = rng.permutation(al_init_length) - n_tt = int(pct_test * len(indices)) - indices_tt = indices[:n_tt] - indices_tr = indices[n_tt:] - df.loc[indices_tt, "test"] = True - df.loc[indices_tr, "train"] = True - # Split train and test the section of each iteration to preserve splits - idx = al_init_length - iters_remaining = (len(df) - al_init_length) // al_queries_per_iter - indices = rng.permutation(al_queries_per_iter) - n_tt = int(pct_test * len(indices)) - for it in range(iters_remaining): - indices_tt = indices[:n_tt] + idx - indices_tr = indices[n_tt:] + idx - df.loc[indices_tt, "test"] = True - df.loc[indices_tr, "train"] = True - idx += al_queries_per_iter - return df diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 97760d002..e4b33402a 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -2,14 +2,15 @@ Base class of GFlowNet environments """ from abc import abstractmethod -from typing import List, Tuple +from copy import deepcopy +from typing import List, Optional, Tuple, Union + import numpy as np import numpy.typing as npt -import pandas as pd import torch from torch.distributions import Categorical from torchtyping import TensorType -import pickle + from gflownet.utils.common import set_device, set_float_precision @@ -20,40 +21,43 @@ class GFlowNetEnv: def __init__( self, - device="cpu", - float_precision=32, - env_id=None, - reward_beta=1, - reward_norm=1.0, - reward_norm_std_mult=0, - reward_func="power", - energies_stats=None, - denorm_proxy=False, + device: str = "cpu", + float_precision: int = 32, + env_id: Union[int, str] = "env", + reward_min: float = 1e-8, + reward_beta: float = 1.0, + reward_norm: float = 1.0, + reward_norm_std_mult: float = 0.0, + reward_func: str = "identity", + energies_stats: List[int] = None, + denorm_proxy: bool = False, proxy=None, oracle=None, - proxy_state_format=None, + proxy_state_format: str = "oracle", + fixed_distribution: Optional[dict] = None, + random_distribution: Optional[dict] = None, **kwargs, ): + # Call reset() to set initial state, done, n_actions + self.reset() # Device - if isinstance(device, str): - self.device = set_device(device) - else: - self.device = device + self.device = 
set_device(device) # Float precision self.float = set_float_precision(float_precision) - # Environment - self.state = [] - self.done = False - self.n_actions = 0 - self.id = env_id - self.min_reward = 1e-8 + # Reward settings + self.min_reward = reward_min + assert self.min_reward > 0 self.reward_beta = reward_beta + assert self.reward_beta > 0 self.reward_norm = reward_norm + assert self.reward_norm > 0 self.reward_norm_std_mult = reward_norm_std_mult self.reward_func = reward_func self.energies_stats = energies_stats self.denorm_proxy = denorm_proxy + # Proxy and oracle self.proxy = proxy + self.setup_proxy() if oracle is None: self.oracle = self.proxy else: @@ -63,46 +67,227 @@ def __init__( else: self.proxy_factor = -1.0 self.proxy_state_format = proxy_state_format - self._true_density = None - self._z = None - self.action_space = [] - self.eos = len(self.action_space) + # Log SoftMax function self.logsoftmax = torch.nn.LogSoftmax(dim=1) - # Assertions - assert self.reward_norm > 0 - assert self.reward_beta > 0 - assert self.min_reward > 0 - - def copy(self): - # return an instance of the environment - return self.__class__(**self.__dict__) - - def set_energies_stats(self, energies_stats): - self.energies_stats = energies_stats - - def set_reward_norm(self, reward_norm): - self.reward_norm = reward_norm + # Action space + self.action_space = self.get_action_space() + self.action_space_torch = torch.tensor( + self.action_space, device=self.device, dtype=self.float + ) + self.action_space_dim = len(self.action_space) + # Max trajectory length + self.max_traj_length = self.get_max_traj_length() + # Policy outputs + self.fixed_policy_output = self.get_policy_output(fixed_distribution) + self.random_policy_output = self.get_policy_output(random_distribution) + self.policy_output_dim = len(self.fixed_policy_output) + self.policy_input_dim = len(self.state2policy()) @abstractmethod - def get_actions_space(self): + def get_action_space(self): """ Constructs list with all possible actions (excluding end of sequence) """ pass - def get_fixed_policy_output(self): + def actions2indices( + self, actions: TensorType["batch_size", "action_dim"] + ) -> TensorType["batch_size"]: """ - Defines the structure of the output of the policy model, from which an - action is to be determined or sampled, by returning a vector with a fixed - random policy. As a baseline, the fixed policy is uniform over the - dimensionality of the action space. + Returns the corresponding indices in the action space of the actions in a batch. """ - return np.ones(len(self.action_space)) + # Expand the action_space tensor: [batch_size, d_actions_space, action_dim] + action_space = torch.unsqueeze(self.action_space_torch, 0).expand( + actions.shape[0], -1, -1 + ) + # Expand the actions tensor: [batch_size, d_actions_space, action_dim] + actions = torch.unsqueeze(actions, 1).expand(-1, self.action_space_dim, -1) + # Take the indices at the d_actions_space dimension where all the elements in + # the action_dim dimension are True + return torch.where(torch.all(actions == action_space, dim=2))[1] - def get_max_traj_len( + def get_mask_invalid_actions_forward( self, - ): - return 1e3 + state: Optional[List] = None, + done: Optional[bool] = None, + ) -> List: + """ + Returns a list of length the action space with values: + - True if the forward action is invalid from the current state. + - False otherwise. + For continuous or hybrid environments, this mask corresponds to the discrete + part of the action space. 
+ """ + return [False for _ in range(self.action_space_dim)] + + def get_mask_invalid_actions_backward( + self, + state: Optional[List] = None, + done: Optional[bool] = None, + parents_a: Optional[List] = None, + ) -> List: + """ + Returns a list of length the action space with values: + - True if the backward action is invalid from the current state. + - False otherwise. + For continuous or hybrid environments, this mask corresponds to the discrete + part of the action space. + + The base implementation below should be common to all discrete spaces as it + relies on get_parents, which is environment-specific and must be implemented. + Continuous environments will probably need to implement its specific version of + this method. + """ + if parents_a is None: + _, parents_a = self.get_parents() + mask = [True for _ in range(self.action_space_dim)] + for pa in parents_a: + mask[self.action_space.index(pa)] = False + return mask + + def get_parents( + self, + state: Optional[List] = None, + done: Optional[bool] = None, + action: Optional[Tuple] = None, + ) -> Tuple[List, List]: + """ + Determines all parents and actions that lead to state. + + In continuous environments, get_parents() should return only the parent from + which action leads to state. + + Args + ---- + state : list + Representation of a state + + done : bool + Whether the trajectory is done. If None, done is taken from instance. + + action : tuple + Last action performed + + Returns + ------- + parents : list + List of parents in state format + + actions : list + List of actions that lead to state for each parent in parents + """ + if state is None: + state = self.state.copy() + if done is None: + done = self.done + if done: + return [state], [(self.eos,)] + parents = [] + actions = [] + return parents, actions + + def step(self, action: Tuple[int]) -> Tuple[List[int], Tuple[int], bool]: + """ + Executes step given an action. + + Args + ---- + action : tuple + Action from the action space. + + Returns + ------- + self.state : list + The sequence after executing the action + + action : int + Action index + + valid : bool + False, if the action is not allowed for the current state, e.g. stop at the + root state + """ + # If env is done, return invalid + if self.done: + return self.state, action, False + # If action not found in action space raise an error + if action not in self.action_space: + raise ValueError( + f"Tried to execute action {action} not present in action space." + ) + action_idx = self.action_space.index(action) + # If action is in invalid mask, exit immediately + if self.get_mask_invalid_actions_forward()[action_idx]: + return self.state, action, False + return None, None, None + + def sample_actions( + self, + policy_outputs: TensorType["n_states", "policy_output_dim"], + sampling_method: str = "policy", + mask_invalid_actions: TensorType["n_states", "policy_output_dim"] = None, + temperature_logits: float = 1.0, + loginf: float = 1000, + ) -> Tuple[List[Tuple], TensorType["n_states"]]: + """ + Samples a batch of actions from a batch of policy outputs. This implementation + is generally valid for all discrete environments but continuous environments + will likely have to implement its own. 
+ """ + device = policy_outputs.device + ns_range = torch.arange(policy_outputs.shape[0]).to(device) + if sampling_method == "uniform": + logits = torch.ones(policy_outputs.shape).to(device) + elif sampling_method == "policy": + logits = policy_outputs + logits /= temperature_logits + if mask_invalid_actions is not None: + logits[mask_invalid_actions] = -loginf + action_indices = Categorical(logits=logits).sample() + logprobs = self.logsoftmax(logits)[ns_range, action_indices] + # Build actions + actions = [self.action_space[idx] for idx in action_indices] + return actions, logprobs + + def get_logprobs( + self, + policy_outputs: TensorType["n_states", "policy_output_dim"], + is_forward: bool, + actions: TensorType["n_states", "actions_dim"], + states_target: TensorType["n_states", "policy_input_dim"], + mask_invalid_actions: TensorType["batch_size", "policy_output_dim"] = None, + loginf: float = 1000, + ) -> TensorType["batch_size"]: + """ + Computes log probabilities of actions given policy outputs and actions. This + implementation is generally valid for all discrete environments but continuous + environments will likely have to implement its own. + """ + device = policy_outputs.device + ns_range = torch.arange(policy_outputs.shape[0]).to(device) + logits = policy_outputs + if mask_invalid_actions is not None: + logits[mask_invalid_actions] = -loginf + action_indices = ( + torch.tensor( + [self.action_space.index(tuple(action.tolist())) for action in actions] + ) + .to(int) + .to(device) + ) + logprobs = self.logsoftmax(logits)[ns_range, action_indices] + return logprobs + + def get_policy_output(self, params: Optional[dict] = None): + """ + Defines the structure of the output of the policy model, from which an + action is to be determined or sampled, by returning a vector with a fixed + random policy. As a baseline, the policy is uniform over the dimensionality of + the action space. + + Continuous environments will generally have to overwrite this method. + """ + return np.ones(self.action_space_dim) def state2proxy(self, state: List = None): """ @@ -129,8 +314,8 @@ def statebatch2proxy(self, states: List[List]) -> npt.NDArray[np.float32]: return np.array(states) def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_proxy_dim"]: + self, states: TensorType["batch_size", "state_dim"] + ) -> TensorType["batch_size", "state_proxy_dim"]: """ Prepares a batch of states in torch "GFlowNet format" for the proxy. """ @@ -155,6 +340,57 @@ def statebatch2oracle(self, states: List[List]): """ return states + def statetorch2policy( + self, states: TensorType["batch_size", "state_dim"] + ) -> TensorType["batch_size", "policy_input_dim"]: + """ + Prepares a batch of states in torch "GFlowNet format" for the policy + """ + return states + + def state2policy(self, state=None): + """ + Converts a state into a format suitable for a machine learning model, such as a + one-hot encoding. + """ + if state is None: + state = self.state + return state + + def statebatch2policy(self, states: List[List]) -> npt.NDArray[np.float32]: + """ + Converts a batch of states into a format suitable for a machine learning model, + such as a one-hot encoding. Returns a numpy array. + """ + return np.array(states) + + def policy2state(self, state_policy: List) -> List: + """ + Converts the model (e.g. one-hot encoding) version of a state given as + argument into a state. 
+ """ + return state_policy + + def state2readable(self, state=None): + """ + Converts a state into human-readable representation. + """ + if state is None: + state = self.state + return str(state) + + def readable2state(self, readable): + """ + Converts a human-readable representation of a state into the standard format. + """ + return readable + + def traj2readable(self, traj=None): + """ + Converts a trajectory into a human-readable string. + """ + return str(traj).replace("(", "[").replace(")", "]").replace(",", "") + def reward(self, state=None, done=None): """ Computes the reward of a state @@ -180,7 +416,9 @@ def reward_batch(self, states: List[List], done=None): return rewards def reward_torchbatch( - self, states: TensorType["batch", "state_dim"], done: TensorType["batch"] = None + self, + states: TensorType["batch_size", "state_dim"], + done: TensorType["batch_size"] = None, ): """ Computes the rewards of a batch of states in "GFlownet format" @@ -224,7 +462,7 @@ def proxy2reward(self, proxy_vals): max=None, ) else: - raise NotImplemented + raise NotImplementedError def reward2proxy(self, reward): """ @@ -241,158 +479,38 @@ def reward2proxy(self, reward): elif self.reward_func == "identity": return self.proxy_factor * reward else: - raise NotImplemented - - def statetorch2policy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "policy_input_dim"]: - """ - Prepares a batch of states in torch "GFlowNet format" for the policy - """ - return states + raise NotImplementedError - def state2policy(self, state=None): - """ - Converts a state into a format suitable for a machine learning model, such as a - one-hot encoding. - """ - if state is None: - state = self.state - return state - - def statebatch2policy(self, states: List[List]) -> npt.NDArray[np.float32]: - """ - Converts a batch of states into a format suitable for a machine learning model, - such as a one-hot encoding. Returns a numpy array. - """ - return np.array(states) - - def policy2state(self, state_policy: List) -> List: - """ - Converts the model (e.g. one-hot encoding) version of a state given as - argument into a state. - """ - return state_policy - - def state2readable(self, state=None): - """ - Converts a state into human-readable representation. - """ - if state is None: - state = self.state - return str(state) - - def readable2state(self, readable): - """ - Converts a human-readable representation of a state into the standard format. - """ - return readable - - def traj2readable(self, traj=None): - """ - Converts a trajectory into a human-readable string. - """ - return str(traj).replace("(", "[").replace(")", "]").replace(",", "") - - def reset(self, env_id=None): + def reset(self, env_id: Union[int, str] = None): """ Resets the environment. """ - self.state = [] + self.state = self.source.copy() self.n_actions = 0 self.done = False self.id = env_id return self - def get_parents(self, state=None, done=None, action=None): + def set_state(self, state: List, done: Optional[bool] = False): """ - Determines all parents and actions that lead to state. - - Args - ---- - state : list - Representation of a state - - done : bool - Whether the trajectory is done. If None, done is taken from instance. - - action : tuple - Last action performed + Sets the state and done of an environment. 
+ """ + self.state = state + self.done = done + return self - Returns - ------- - parents : list - List of parents in state format + def copy(self): + # return self.__class__(**self.__dict__) + return deepcopy(self) - actions : list - List of actions that lead to state for each parent in parents - """ - if state is None: - state = self.state.copy() - if done is None: - done = self.done - if done: - return [state], [self.eos] - else: - parents = [] - actions = [] - return parents, actions + def set_energies_stats(self, energies_stats): + self.energies_stats = energies_stats - def sample_actions( - self, - policy_outputs: TensorType["n_states", "policy_output_dim"], - sampling_method: str = "policy", - mask_invalid_actions: TensorType["n_states", "policy_output_dim"] = None, - temperature_logits: float = 1.0, - loginf: float = 1000, - ) -> Tuple[List[Tuple], TensorType["n_states"]]: - """ - Samples a batch of actions from a batch of policy outputs. This implementation - is generally valid for all discrete environments. - """ - device = policy_outputs.device - ns_range = torch.arange(policy_outputs.shape[0]).to(device) - if sampling_method == "uniform": - logits = torch.ones(policy_outputs.shape).to(device) - elif sampling_method == "policy": - logits = policy_outputs - logits /= temperature_logits - if mask_invalid_actions is not None: - logits[mask_invalid_actions] = -loginf - action_indices = Categorical(logits=logits).sample() - logprobs = self.logsoftmax(logits)[ns_range, action_indices] - # Build actions - actions = [self.action_space[idx] for idx in action_indices] - return actions, logprobs + def set_reward_norm(self, reward_norm): + self.reward_norm = reward_norm - def get_logprobs( - self, - policy_outputs: TensorType["n_states", "policy_output_dim"], - is_forward: bool, - actions: TensorType["n_states", 2], - states_target: TensorType["n_states", "policy_input_dim"], - mask_invalid_actions: TensorType["batch_size", "policy_output_dim"] = None, - loginf: float = 1000, - ) -> TensorType["batch_size"]: - """ - Computes log probabilities of actions given policy outputs and actions. This - implementation is generally valid for all discrete environments. - """ - device = policy_outputs.device - ns_range = torch.arange(policy_outputs.shape[0]).to(device) - logits = policy_outputs - if mask_invalid_actions is not None: - logits[mask_invalid_actions] = -loginf - # TODO: fix need to convert to tuple: implement as in continuous - action_indices = ( - torch.tensor( - [self.action_space.index(tuple(action.tolist())) for action in actions] - ) - .to(int) - .to(device) - ) - logprobs = self.logsoftmax(logits)[ns_range, action_indices] - return logprobs + def get_max_traj_length(self): + return 1e3 def get_trajectories( self, traj_list, traj_actions_list, current_traj, current_actions @@ -433,318 +551,6 @@ def get_trajectories( ) return traj_list, traj_actions_list - def step(self, action_idx): - """ - Executes step given an action. - - Args - ---- - action_idx : int - Index of action in the action space. a == eos indicates "stop action" - - Returns - ------- - self.state : list - The sequence after executing the action - - action_idx : int - Action index - - valid : bool - False, if the action is not allowed for the current state, e.g. 
stop at the - root state - """ - if action < self.eos: - self.done = False - valid = True - else: - self.done = True - valid = True - self.n_actions += 1 - return self.state, action, valid - - def no_eos_mask(self, state=None): - """ - Returns True if no eos action is allowed given state - """ - if state is None: - state = self.state - return False - - def get_mask_invalid_actions_forward(self, state=None, done=None): - """ - Returns a vector of length the action space + 1: True if forward action is - invalid given the current state, False otherwise. - """ - mask = [False for _ in range(len(self.action_space))] - return mask - - def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=None): - """ - Returns a vector with the length of the discrete part of the action space + 1: - True if action is invalid going backward given the current state, False - otherwise. - """ - if parents_a is None: - _, parents_a = self.get_parents() - mask = [True for _ in range(len(self.action_space))] - for pa in parents_a: - mask[self.action_space.index(pa)] = False - return mask - - def set_state(self, state, done): - """ - Sets the state and done of an environment. - """ - self.state = state - self.done = done - return self - - def true_density(self): - """ - Computes the reward density (reward / sum(rewards)) of the whole space - - Returns - ------- - Tuple: - - normalized reward for each state - - un-normalized reward - - states - """ - return (None, None, None) - - @staticmethod - def np2df(*args): - """ - Args - ---- - """ - return None - - -class Buffer: - """ - Implements the functionality to manage various buffers of data: the records of - training samples, the train and test data sets, a replay buffer for training, etc. - """ - - def __init__( - self, - env, - make_train_test=False, - replay_capacity=0, - output_csv=None, - data_path=None, - train=None, - test=None, - logger=None, - **kwargs, - ): - self.logger = logger - self.env = env - self.replay_capacity = replay_capacity - self.main = pd.DataFrame(columns=["state", "traj", "reward", "energy", "iter"]) - self.replay = pd.DataFrame( - np.empty((self.replay_capacity, 5), dtype=object), - columns=["state", "traj", "reward", "energy", "iter"], - ) - self.replay.reward = pd.to_numeric(self.replay.reward) - self.replay.energy = pd.to_numeric(self.replay.energy) - self.replay.reward = [-1 for _ in range(self.replay_capacity)] - # Define train and test data sets - if train is not None and "type" in train: - self.train_type = train.type - else: - self.train_type = None - self.train, dict_tr = self.make_data_set(train) - if ( - self.train is not None - and "output_csv" in train - and train.output_csv is not None - ): - self.train.to_csv(train.output_csv) - if ( - dict_tr is not None - and "output_pkl" in train - and train.output_pkl is not None - ): - with open(train.output_pkl, "wb") as f: - pickle.dump(dict_tr, f) - self.train_pkl = train.output_pkl - else: - print( - """ - Important: offline trajectories will NOT be sampled. In order to sample - offline trajectories, the train configuration of the buffer should be - complete and feasible and an output pkl file should be defined in - env.buffer.train.output_pkl. 
- """ - ) - self.train_pkl = None - if test is not None and "type" in test: - self.test_type = test.type - else: - self.train_type = None - self.test, dict_tt = self.make_data_set(test) - if ( - self.test is not None - and "output_csv" in test - and test.output_csv is not None - ): - self.test.to_csv(test.output_csv) - if dict_tt is not None and "output_pkl" in test and test.output_pkl is not None: - with open(test.output_pkl, "wb") as f: - pickle.dump(dict_tt, f) - self.test_pkl = test.output_pkl - else: - print( - """ - Important: test metrics will NOT be computed. In order to compute - test metrics the test configuration of the buffer should be complete and - feasible and an output pkl file should be defined in - env.buffer.test.output_pkl. - """ - ) - self.test_pkl = None - # Compute buffer statistics - if self.train is not None: - ( - self.mean_tr, - self.std_tr, - self.min_tr, - self.max_tr, - self.max_norm_tr, - ) = self.compute_stats(self.train) - if self.test is not None: - self.mean_tt, self.std_tt, self.min_tt, self.max_tt, _ = self.compute_stats( - self.test - ) - - def add( - self, - states, - trajs, - rewards, - energies, - it, - buffer="main", - criterion="greater", - ): - if buffer == "main": - self.main = pd.concat( - [ - self.main, - pd.DataFrame( - { - "state": [self.env.state2readable(s) for s in states], - "traj": [self.env.traj2readable(p) for p in trajs], - "reward": rewards, - "energy": energies, - "iter": it, - } - ), - ], - axis=0, - join="outer", - ) - elif buffer == "replay" and self.replay_capacity > 0: - if criterion == "greater": - self.replay = self._add_greater(states, trajs, rewards, energies, it) - - def _add_greater( - self, - states, - trajs, - rewards, - energies, - it, - ): - rewards_old = self.replay["reward"].values - rewards_new = rewards.copy() - while np.max(rewards_new) > np.min(rewards_old): - idx_new_max = np.argmax(rewards_new) - readable_state = self.env.state2readable(states[idx_new_max]) - if not self.replay["state"].isin([readable_state]).any(): - self.replay.iloc[self.replay.reward.argmin()] = { - "state": self.env.state2readable(states[idx_new_max]), - "traj": self.env.traj2readable(trajs[idx_new_max]), - "reward": rewards[idx_new_max], - "energy": energies[idx_new_max], - "iter": it, - } - rewards_old = self.replay["reward"].values - rewards_new[idx_new_max] = -1 - return self.replay - - def make_data_set(self, config): - """ - Constructs a data set as a DataFrame according to the configuration. - """ - if config is None: - return None, None - elif "path" in config and config.path is not None: - path = self.logger.logdir / Path("data") / config.path - df = pd.read_csv(path, index_col=0) - # TODO: check if state2readable transformation is required. 
- return df - elif "type" not in config: - return None, None - elif config.type == "all" and hasattr(self.env, "get_all_terminating_states"): - samples = self.env.get_all_terminating_states() - elif ( - config.type == "grid" - and "n" in config - and hasattr(self.env, "get_grid_terminating_states") - ): - samples = self.env.get_grid_terminating_states(config.n) - elif ( - config.type == "uniform" - and "n" in config - and "seed" in config - and hasattr(self.env, "get_uniform_terminating_states") - ): - samples = self.env.get_uniform_terminating_states(config.n, config.seed) - else: - return None, None - energies = self.env.oracle(self.env.statebatch2oracle(samples)).tolist() - df = pd.DataFrame( - { - "samples": [self.env.state2readable(s) for s in samples], - "energies": energies, - } - ) - return df, {"x": samples, "energy": energies} - - def compute_stats(self, data): - mean_data = data["energies"].mean() - std_data = data["energies"].std() - min_data = data["energies"].min() - max_data = data["energies"].max() - data_zscores = (data["energies"] - mean_data) / std_data - max_norm_data = data_zscores.max() - return mean_data, std_data, min_data, max_data, max_norm_data - - def sample( - self, - ): - pass - - def __len__(self): - return self.capacity - - @property - def transitions(self): - pass - - def save( - self, - ): - pass - - @classmethod - def load(): - pass - - @property - def dummy(self): - pass + def setup_proxy(self): + if self.proxy: + self.proxy.setup(self) diff --git a/gflownet/envs/crystals.py b/gflownet/envs/crystals.py index 6a814f070..fecb9ece3 100644 --- a/gflownet/envs/crystals.py +++ b/gflownet/envs/crystals.py @@ -35,9 +35,11 @@ def __init__( Args ---------- elements : list or int - Elements that will be used for construction of crystal. Either list, in which case every value should - indicate the atomic number of an element, or int, in which case n consecutive atomic numbers will - be used. Note that we assume this will correspond to real atomic numbers, i.e. start from 1, not 0. + Elements that will be used for construction of crystal. Either list, in + which case every value should indicate the atomic number of an element, or + int, in which case n consecutive atomic numbers will be used. Note that we + assume this will correspond to real atomic numbers, i.e. start from 1, not + 0. 
max_diff_elem : int Maximum number of unique elements in the crystal @@ -52,35 +54,35 @@ def __init__( Maximum number of atoms that can be used to construct a crystal min_atom_i : int - Minimum number of elements of each used kind that needs to be used to construct a crystal + Minimum number of elements of each used kind that needs to be used to + construct a crystal max_atom_i : int - Maximum number of elements of each kind that can be used to construct a crystal + Maximum number of elements of each kind that can be used to construct a + crystal oxidation_states : (optional) dict - Mapping from ints (representing elements) to lists of different oxidation states + Mapping from ints (representing elements) to lists of different oxidation + states alphabet : (optional) dict - Mapping from ints (representing elements) to strings containing human-readable elements' names + Mapping from ints (representing elements) to strings containing + human-readable elements' names required_elements : (optional) list - List of elements that must be present in a crystal for it to represent a valid end state + List of elements that must be present in a crystal for it to represent a + valid end state """ - super().__init__(**kwargs) - if isinstance(elements, int): elements = [i + 1 for i in range(elements)] - if len(elements) != len(set(elements)): raise ValueError( f"Provided elements must be unique, detected {len(elements) - len(set(elements))} duplicates." ) - if any(e <= 0 for e in elements): raise ValueError( "Provided elements should be non-negative (assumed indexing from 1 for H)." ) - self.elements = sorted(elements) self.max_diff_elem = max_diff_elem self.min_diff_elem = min_diff_elem @@ -97,15 +99,15 @@ def __init__( self.required_elements = ( required_elements if required_elements is not None else [] ) - - self.source = [0 for _ in range(periodic_table)] self.elem2idx = {e: i for i, e in enumerate(self.elements)} self.idx2elem = {i: e for i, e in enumerate(self.elements)} - self.eos = -1 - self.action_space = self.get_actions_space() - self.reset() + # Source state: 0 atoms for all elements + self.source = [0 for _ in self.elements] + # End-of-sequence action + self.eos = (-1, -1) + super().__init__(**kwargs) - def get_actions_space(self): + def get_action_space(self): """ Constructs list with all possible actions. 
An action is described by a tuple (element, n), indicating that the count of element will be @@ -115,10 +117,10 @@ def get_actions_space(self): assert self.max_atom_i > self.min_atom_i valid_word_len = np.arange(self.min_atom_i, self.max_atom_i + 1) actions = [(element, n) for n in valid_word_len for element in self.elements] - actions.append((self.eos, 0)) + actions.append(self.eos) return actions - def get_max_traj_len(self): + def get_max_traj_length(self): return min(len(self.state), self.max_atoms // self.min_atom_i) def get_mask_invalid_actions(self, state=None, done=None): @@ -132,18 +134,18 @@ def get_mask_invalid_actions(self, state=None, done=None): done = self.done if done: - return [True for _ in range(len(self.action_space))] + return [True for _ in range(self.action_space_dim)] mask = [False for _ in self.action_space] state_elem = [self.idx2elem[i] for i, e in enumerate(state) if e > 0] n_state_atoms = sum(state) if n_state_atoms < self.min_atoms: - mask[self.eos] = True + mask[-1] = True if len(state_elem) < self.min_diff_elem: - mask[self.eos] = True + mask[-1] = True if any(r not in state_elem for r in self.required_elements): - mask[self.eos] = True + mask[-1] = True for idx, (element, n) in enumerate(self.action_space[:-1]): if state[self.elem2idx[element]] > 0: @@ -170,7 +172,8 @@ def state2oracle(self, state: List = None) -> Tensor: Returns ---- oracle_state : Tensor - Tensor containing # of Li atoms, total # of atoms, and fractions of individual elements + Tensor containing # of Li atoms, total # of atoms, and fractions of + individual elements """ if state is None: state = self.state @@ -237,9 +240,9 @@ def get_parents(self, state=None, done=None, actions=None): Args ---- state : list - Representation of a state as a list of length equal to that of self.elements, - where i-th value contains the count of atoms for i-th element, from 0 to - self.max_atoms_i. + Representation of a state as a list of length equal to that of + self.elements, where i-th value contains the count of atoms for i-th + element, from 0 to self.max_atoms_i. done : bool Whether the trajectory is done. If None, done is taken from instance. @@ -280,7 +283,7 @@ def step(self, action: Tuple[int, int]) -> Tuple[List[int], Tuple[int, int], boo Args ---- action : tuple - Action to be executed. See: get_actions_space() + Action to be executed. See: get_action_space() Returns ------- @@ -293,11 +296,14 @@ def step(self, action: Tuple[int, int]) -> Tuple[List[int], Tuple[int, int], boo valid : bool False, if the action is not allowed for the current state. 
""" + # If done, return invalid + if self.done: + return self.state, action, False # If only possible action is eos, then force eos if sum(self.state) == self.max_atoms: self.done = True self.n_actions += 1 - return self.state, (self.eos, 0), True + return self.state, self.eos, True # If action not found in action space raise an error action_idx = None for i, a in enumerate(self.action_space): @@ -312,7 +318,7 @@ def step(self, action: Tuple[int, int]) -> Tuple[List[int], Tuple[int, int], boo if self.get_mask_invalid_actions()[action_idx]: return self.state, action, False # If action is not eos, then perform action - if action[0] != self.eos: + if action != self.eos: atomic_number, num = action idx = self.elem2idx[atomic_number] state_next = self.state[:] @@ -326,7 +332,7 @@ def step(self, action: Tuple[int, int]) -> Tuple[List[int], Tuple[int, int], boo return self.state, action, valid # If action is eos, then perform eos else: - if self.get_mask_invalid_actions()[self.eos]: + if self.get_mask_invalid_actions()[-1]: valid = False else: if self._can_produce_neutral_charge(): @@ -335,7 +341,7 @@ def step(self, action: Tuple[int, int]) -> Tuple[List[int], Tuple[int, int], boo self.n_actions += 1 else: valid = False - return self.state, (self.eos, 0), valid + return self.state, self.eos, valid def _can_produce_neutral_charge(self, state: Optional[List[int]] = None) -> bool: """ diff --git a/gflownet/envs/ctorus.py b/gflownet/envs/ctorus.py index 94bbff091..a504a5f22 100644 --- a/gflownet/envs/ctorus.py +++ b/gflownet/envs/ctorus.py @@ -1,16 +1,18 @@ """ Classes to represent hyper-torus environments """ -from typing import List, Tuple import itertools +from typing import List, Tuple + import numpy as np import numpy.typing as npt import pandas as pd import torch -from gflownet.envs.htorus import HybridTorus -from torch.distributions import Categorical, Uniform, VonMises, MixtureSameFamily +from torch.distributions import Categorical, MixtureSameFamily, Uniform, VonMises from torchtyping import TensorType +from gflownet.envs.htorus import HybridTorus + class ContinuousTorus(HybridTorus): """ @@ -33,15 +35,15 @@ class ContinuousTorus(HybridTorus): def __init__(self, **kwargs): super().__init__(**kwargs) - def get_actions_space(self): + def get_action_space(self): """ - The actions are tuples of length 2 * n_dim, where positions d and d+1 in the - tuple correspond to dimension d and the increment of dimension d, - respectively. EOS is indicated by a tuple whose first element ins self.eos. + The actions are tuples of length n_dim, where the value at position d indicates + the increment of dimension d. EOS is indicated by increments of np.inf for all + dimensions. """ - pairs = [(dim, 0.0) for dim in range(self.n_dim)] - actions = [tuple([el for pair in pairs for el in pair])] - actions += [tuple([self.eos] + [0.0 for _ in range(self.n_dim * 2 - 1)])] + self.eos = tuple([np.inf for _ in range(self.n_dim)]) + generic_action = tuple([0.0 for _ in range(self.n_dim)]) + actions = [generic_action, self.eos] return actions def get_policy_output(self, params: dict): @@ -64,8 +66,8 @@ def get_policy_output(self, params: dict): - d * c * 3 + 2: log concentration of Von Mises distribution for dim. d, comp. 
c """ policy_output = np.ones(self.n_dim * self.n_comp * 3) - policy_output[1::3] = params.vonmises_mean - policy_output[2::3] = params.vonmises_concentration + policy_output[1::3] = params["vonmises_mean"] + policy_output[2::3] = params["vonmises_concentration"] return policy_output def get_mask_invalid_actions_forward(self, state=None, done=None): @@ -130,9 +132,12 @@ def get_parents( if done is None: done = self.done if done: - return [state], [self.action_space[-1]] + return [state], [self.eos] + # If source state + elif state[-1] == 0: + return [], [] else: - for dim, angle in zip(action[0::2], action[1::2]): + for dim, angle in enumerate(action): state[int(dim)] = (state[int(dim)] - angle) % (2 * np.pi) state[-1] -= 1 parents = [state] @@ -142,7 +147,7 @@ def sample_actions( self, policy_outputs: TensorType["n_states", "policy_output_dim"], sampling_method: str = "policy", - mask_stop_actions: TensorType["n_states", "1"] = None, + mask_invalid_actions: TensorType["n_states", "1"] = None, temperature_logits: float = 1.0, loginf: float = 1000, ) -> Tuple[List[Tuple], TensorType["n_states"]]: @@ -150,7 +155,7 @@ def sample_actions( Samples a batch of actions from a batch of policy outputs. """ device = policy_outputs.device - mask_states_sample = ~mask_stop_actions.flatten() + mask_states_sample = ~mask_invalid_actions.flatten() n_states = policy_outputs.shape[0] # Sample angle increments angles = torch.zeros(n_states, self.n_dim).to(device) @@ -183,16 +188,10 @@ def sample_actions( ) logprobs = torch.sum(logprobs, axis=1) # Build actions - actions_tensor = ( - torch.repeat_interleave(torch.arange(0, self.n_dim), 2) - .repeat(n_states, 1) - .to(dtype=self.float, device=device) + actions_tensor = torch.inf * torch.ones( + angles.shape, dtype=self.float, device=device ) - actions_tensor[mask_states_sample, 1::2] = angles[mask_states_sample] - actions_tensor[mask_stop_actions.flatten()] = torch.zeros( - actions_tensor.shape[1] - ).to(actions_tensor) - actions_tensor[mask_stop_actions.flatten(), 0] = 2.0 + actions_tensor[mask_states_sample, :] = angles[mask_states_sample] actions = [tuple(a.tolist()) for a in actions_tensor] return actions, logprobs @@ -200,18 +199,17 @@ def get_logprobs( self, policy_outputs: TensorType["n_states", "policy_output_dim"], is_forward: bool, - actions: TensorType["n_states", 2], + actions: TensorType["n_states", "n_dim"], states_target: TensorType["n_states", "policy_input_dim"], - mask_stop_actions: TensorType["n_states", "1"] = None, + mask_invalid_actions: TensorType["n_states", "1"] = None, loginf: float = 1000, ) -> TensorType["batch_size"]: """ Computes log probabilities of actions given policy outputs and actions. """ device = policy_outputs.device - mask_states_sample = ~mask_stop_actions.flatten() + mask_states_sample = ~mask_invalid_actions.flatten() n_states = policy_outputs.shape[0] - angles = actions[:, 1::2] logprobs = torch.zeros(n_states, self.n_dim).to(device) if torch.any(mask_states_sample): mix_logits = policy_outputs[mask_states_sample, 0::3].reshape( @@ -230,23 +228,20 @@ def get_logprobs( ) distr_angles = MixtureSameFamily(mix, vonmises) logprobs[mask_states_sample] = distr_angles.log_prob( - angles[mask_states_sample] + actions[mask_states_sample] ) logprobs = torch.sum(logprobs, axis=1) return logprobs - def step( - self, action: Tuple[int, float] - ) -> Tuple[List[float], Tuple[int, float], bool]: + def step(self, action: Tuple[float]) -> Tuple[List[float], Tuple[int, float], bool]: """ Executes step given an action. 
Args ---- action : tuple - Action to be executed. An action is a tuple with either: - - (self.eos, 0.0) with two values: - (dimension, magnitude). + Action to be executed. An action is a vector where the value at position d + indicates the increment in the angle at dimension d. Returns ------- @@ -267,15 +262,15 @@ def step( elif self.n_actions == self.length_traj: self.done = True self.n_actions += 1 - return self.state, self.action_space[-1], True - # If action is not eos, then perform action - elif action[0] != self.eos: + return self.state, self.eos, True + # If action is eos, then it is invalid + elif action == self.eos: + return self.state, action, False + # Otherwise perform action + else: self.n_actions += 1 - for dim, angle in zip(action[0::2], action[1::2]): + for dim, angle in enumerate(action): self.state[int(dim)] += angle self.state[int(dim)] = self.state[int(dim)] % (2 * np.pi) self.state[-1] = self.n_actions return self.state, action, True - # If action is eos, then it is invalid - else: - return self.state, action, False diff --git a/gflownet/envs/grid.py b/gflownet/envs/grid.py index 5f8485a5a..3ab922339 100644 --- a/gflownet/envs/grid.py +++ b/gflownet/envs/grid.py @@ -1,27 +1,46 @@ """ Classes to represent a hyper-grid environments """ -from typing import List, Tuple import itertools +from typing import List, Optional, Tuple + import numpy as np import numpy.typing as npt import torch from torchtyping import TensorType + from gflownet.envs.base import GFlowNetEnv class Grid(GFlowNetEnv): """ - Hyper-grid environment + Hyper-grid environment: A grid with n_dim dimensions and length cells per + dimensions. + + The state space is the entire grid and each state is represented by the vector of + coordinates of each dimensions. For example, in 3D, the origin will be at [0, 0, 0] + and after incrementing dimension 0 by 2, dimension 1 by 3 and dimension 3 by 1, the + state would be [2, 3, 1]. + + The action space is the increment to be applied to each dimension. For instance, + (0, 0, 1) will increment dimension 2 by 1 and the action that goes from [1, 1, 1] + to [2, 3, 1] is (1, 2, 0). Attributes ---------- - ndim : int + n_dim : int Dimensionality of the grid length : int Size of the grid (cells per dimension) + max_increment : int + Maximum increment of each dimension by the actions. + + max_dim_per_action : int + Maximum number of dimensions to increment per action. If -1, then + max_dim_per_action is set to n_dim. 
+ cell_min : float Lower bound of the cells range @@ -31,54 +50,65 @@ class Grid(GFlowNetEnv): def __init__( self, - n_dim=2, - length=3, - min_step_len=1, - max_step_len=1, - cell_min=-1, - cell_max=1, + n_dim: int = 2, + length: int = 3, + max_increment: int = 1, + max_dim_per_action: int = 1, + cell_min: float = -1, + cell_max: float = 1, **kwargs, ): - super().__init__(**kwargs) + assert n_dim > 0 + assert length > 1 + assert max_increment > 0 + assert max_dim_per_action == -1 or max_dim_per_action > 0 self.n_dim = n_dim - self.eos = self.n_dim - self.source = [0 for _ in range(self.n_dim)] self.length = length - self.min_step_len = min_step_len - self.max_step_len = max_step_len + self.max_increment = max_increment + if max_dim_per_action == -1: + max_dim_per_action = self.n_dim + self.max_dim_per_action = max_dim_per_action self.cells = np.linspace(cell_min, cell_max, length) - self.reset() - self.action_space = self.get_actions_space() - self.fixed_policy_output = self.get_fixed_policy_output() - self.random_policy_output = self.get_fixed_policy_output() - self.policy_output_dim = len(self.fixed_policy_output) - self.policy_input_dim = len(self.state2policy()) + # Source state: position 0 at all dimensions + self.source = [0 for _ in range(self.n_dim)] + # End-of-sequence action + self.eos = tuple([0 for _ in range(self.n_dim)]) + # Base class init + super().__init__(**kwargs) + # Proxy format + # TODO: assess if really needed if self.proxy_state_format == "ohe": self.statebatch2proxy = self.statebatch2policy elif self.proxy_state_format == "oracle": self.statebatch2proxy = self.statebatch2oracle self.statetorch2proxy = self.statetorch2oracle - # Set up proxy - self.proxy.n_dim = self.n_dim - self.proxy.setup() - def get_actions_space(self): + def get_action_space(self): """ - Constructs list with all possible actions, including eos. + Constructs list with all possible actions, including eos. An action is + represented by a vector of length n_dim where each index d indicates the + increment to apply to dimension d of the hyper-grid. """ - valid_steplens = np.arange(self.min_step_len, self.max_step_len + 1) - dims = [a for a in range(self.n_dim)] + increments = [el for el in range(self.max_increment + 1)] actions = [] - for r in valid_steplens: - actions_r = [el for el in itertools.product(dims, repeat=r)] - actions += actions_r - actions += [(self.eos,)] + for action in itertools.product(increments, repeat=self.n_dim): + if ( + sum(action) != 0 + and len([el for el in action if el > 0]) <= self.max_dim_per_action + ): + actions.append(tuple(action)) + actions.append(self.eos) return actions - def get_mask_invalid_actions_forward(self, state=None, done=None): + def get_mask_invalid_actions_forward( + self, + state: Optional[List] = None, + done: Optional[bool] = None, + ) -> List: """ - Returns a vector of length the action space + 1: True if forward action is - invalid given the current state, False otherwise. + Returns a list of length the action space with values: + - True if the forward action is invalid from the current state. + - False otherwise. 
""" if state is None: state = self.state.copy() @@ -87,38 +117,21 @@ def get_mask_invalid_actions_forward(self, state=None, done=None): if done: return [True for _ in range(self.policy_output_dim)] mask = [False for _ in range(self.policy_output_dim)] - for idx, a in enumerate(self.action_space[:-1]): - for d in a: - if state[d] + 1 >= self.length: - mask[idx] = True - break + for idx, action in enumerate(self.action_space[:-1]): + child = state.copy() + for dim, incr in enumerate(action): + child[dim] += incr + if any(el >= self.length for el in child): + mask[idx] = True return mask - def true_density(self): - # Return pre-computed true density if already stored - if self._true_density is not None: - return self._true_density - # Calculate true density - all_states = np.int32( - list(itertools.product(*[list(range(self.length))] * self.n_dim)) - ) - state_mask = np.array( - [len(self.get_parents(s, False)[0]) > 0 or sum(s) == 0 for s in all_states] - ) - all_oracle = self.state2oracle(all_states) - rewards = self.oracle(all_oracle)[state_mask] - self._true_density = ( - rewards / rewards.sum(), - rewards, - list(map(tuple, all_states[state_mask])), - ) - return self._true_density - - def state2oracle(self, state: List = None): + def state2oracle(self, state: List = None) -> List: """ Prepares a state in "GFlowNet format" for the oracles: a list of length n_dim with values in the range [cell_min, cell_max] for each state. + See: state2policy() + Args ---- state : list @@ -127,16 +140,22 @@ def state2oracle(self, state: List = None): if state is None: state = self.state.copy() return ( - self.state2policy(state).reshape((self.n_dim, self.length)) - * self.cells[None, :] - ).sum(axis=1) + ( + np.array(self.state2policy(state)).reshape((self.n_dim, self.length)) + * self.cells[None, :] + ) + .sum(axis=1) + .tolist() + ) def statebatch2oracle( self, states: List[List] ) -> TensorType["batch", "state_oracle_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the oracles: a list of - length n_dim with values in the range [cell_min, cell_max] for each state. + Prepares a batch of states in "GFlowNet format" for the oracles: each state is + a vector of length n_dim with values in the range [cell_min, cell_max]. + + See: statetorch2oracle() Args ---- @@ -151,7 +170,10 @@ def statetorch2oracle( self, states: TensorType["batch", "state_dim"] ) -> TensorType["batch", "state_oracle_dim"]: """ - Prepares a batch of states in torch "GFlowNet format" for the oracle. + Prepares a batch of states in "GFlowNet format" for the oracles: each state is + a vector of length n_dim with values in the range [cell_min, cell_max]. + + See: statetorch2policy() """ return ( self.statetorch2policy(states).reshape( @@ -239,17 +261,12 @@ def state2readable(self, state, alphabet={}): """ return str(state).replace("(", "[").replace(")", "]").replace(",", "") - def reset(self, env_id=None): - """ - Resets the environment. - """ - self.state = self.source.copy() - self.n_actions = 0 - self.done = False - self.id = env_id - return self - - def get_parents(self, state=None, done=None, action=None): + def get_parents( + self, + state: Optional[List] = None, + done: Optional[bool] = None, + action: Optional[Tuple] = None, + ) -> Tuple[List, List]: """ Determines all parents and actions that lead to state. 
@@ -278,20 +295,20 @@ def get_parents(self, state=None, done=None, action=None): if done is None: done = self.done if done: - return [state], [(self.eos,)] + return [state], [self.eos] else: parents = [] actions = [] - for idx, a in enumerate(self.action_space[:-1]): - state_aux = state.copy() - for a_sub in a: - if state_aux[a_sub] > 0: - state_aux[a_sub] -= 1 + for idx, action in enumerate(self.action_space[:-1]): + parent = state.copy() + for dim, incr in enumerate(action): + if parent[dim] - incr >= 0: + parent[dim] -= incr else: break else: - parents.append(state_aux) - actions.append(a) + parents.append(parent) + actions.append(action) return parents, actions def step(self, action: Tuple[int]) -> Tuple[List[int], Tuple[int], bool]: @@ -315,19 +332,31 @@ def step(self, action: Tuple[int]) -> Tuple[List[int], Tuple[int], bool]: valid : bool False, if the action is not allowed for the current state. """ + # If done, return invalid if self.done: return self.state, action, False + # If action not found in action space raise an error + if action not in self.action_space: + raise ValueError( + f"Tried to execute action {action} not present in action space." + ) + else: + action_idx = self.action_space.index(action) + # If action is in invalid mask, return invalid + if self.get_mask_invalid_actions_forward()[action_idx]: + return self.state, action, False + # TODO: simplify by relying on mask # If only possible action is eos, then force eos # All dimensions are at the maximum length if all([s == self.length - 1 for s in self.state]): self.done = True self.n_actions += 1 - return self.state, (self.eos,), True + return self.state, self.eos, True # If action is not eos, then perform action - elif action[0] != self.eos: + elif action != self.eos: state_next = self.state.copy() - for a in action: - state_next[a] += 1 + for dim, incr in enumerate(action): + state_next[dim] += incr if any([s >= self.length for s in state_next]): valid = False else: @@ -339,7 +368,10 @@ def step(self, action: Tuple[int]) -> Tuple[List[int], Tuple[int], bool]: else: self.done = True self.n_actions += 1 - return self.state, (self.eos,), True + return self.state, self.eos, True + + def get_max_traj_length(self): + return self.n_dim * self.length def get_all_terminating_states(self) -> List[List]: all_x = np.int32( diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index c2cb7a1b8..3430d7f38 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -1,19 +1,22 @@ """ Classes to represent hyper-torus environments """ +import itertools +import re from copy import deepcopy from typing import List, Tuple -import itertools + +import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt import pandas as pd -import matplotlib.pyplot as plt import torch -from gflownet.utils.common import torch2np -from gflownet.envs.base import GFlowNetEnv -from torch.distributions import Categorical, Uniform, VonMises, Bernoulli -from torchtyping import TensorType from sklearn.neighbors import KernelDensity +from torch.distributions import Bernoulli, Categorical, Uniform, VonMises +from torchtyping import TensorType + +from gflownet.envs.base import GFlowNetEnv +from gflownet.utils.common import torch2np class HybridTorus(GFlowNetEnv): @@ -36,22 +39,29 @@ class HybridTorus(GFlowNetEnv): def __init__( self, - n_dim=2, - length_traj=1, - n_comp=1, - policy_encoding_dim_per_angle=None, - do_nonzero_source_prob=True, - fixed_distribution=dict, - random_distribution=dict, - vonmises_min_concentration=1e-3, + 
+        n_dim: int = 2,
+        length_traj: int = 1,
+        n_comp: int = 1,
+        policy_encoding_dim_per_angle: int = None,
+        do_nonzero_source_prob: bool = True,
+        vonmises_min_concentration: float = 1e-3,
+        fixed_distribution: dict = {
+            "vonmises_mean": 0.0,
+            "vonmises_concentration": 0.5,
+        },
+        random_distribution: dict = {
+            "vonmises_mean": 0.0,
+            "vonmises_concentration": 0.001,
+        },
         **kwargs,
     ):
-        super().__init__(**kwargs)
-        self.policy_encoding_dim_per_angle = policy_encoding_dim_per_angle
+        assert n_dim > 0
+        assert length_traj > 0
+        assert n_comp > 0
         self.continuous = True
         self.n_dim = n_dim
-        self.eos = self.n_dim
         self.length_traj = length_traj
+        self.policy_encoding_dim_per_angle = policy_encoding_dim_per_angle
         # Parameters of fixed policy distribution
         self.n_comp = n_comp
         if do_nonzero_source_prob:
@@ -59,35 +69,35 @@ def __init__(
         else:
             self.n_params_per_dim = 3
         self.vonmises_min_concentration = vonmises_min_concentration
-        # Initialize angles and state attributes
+        # Source state: position 0 at all dimensions and number of actions 0
         self.source_angles = [0.0 for _ in range(self.n_dim)]
-        # States are the concatenation of the angle state and number of actions
         self.source = self.source_angles + [0]
-        self.reset()
-        self.action_space = self.get_actions_space()
-        self.fixed_policy_output = self.get_policy_output(fixed_distribution)
-        self.random_policy_output = self.get_policy_output(random_distribution)
-        self.policy_output_dim = len(self.fixed_policy_output)
-        self.policy_input_dim = len(self.state2policy())
-        self.logsoftmax = torch.nn.LogSoftmax(dim=1)
-        # Oracle
+        # End-of-sequence action: (n_dim, 0)
+        self.eos = (self.n_dim, 0)
+        # TODO: assess if really needed
         self.state2oracle = self.state2proxy
         self.statebatch2oracle = self.statebatch2proxy
-        # Setup proxy
-        self.proxy.set_n_dim(self.n_dim)
-
-    def copy(self):
-        return deepcopy(self)
+        # Base class init
+        super().__init__(
+            fixed_distribution=fixed_distribution,
+            random_distribution=random_distribution,
+            **kwargs,
+        )
 
-    def get_actions_space(self):
+    def get_action_space(self):
         """
-        Constructs list with all possible actions. The actions are tuples with two
-        values: (dimension, magnitude) where dimension indicates the index of the
-        dimension on which the action is to be performed and magnitude indicates the
-        increment of the angle in radians.
+        Since this is a hybrid (continuous/discrete) environment, this method
+        constructs a list with the discrete actions.
+
+        The actions are tuples with two values: (dimension, magnitude) where dimension
+        indicates the index of the dimension on which the action is to be performed and
+        magnitude indicates the increment of the angle in radians.
+
+        The (discrete) action space is then one tuple per dimension (with 0 increment),
+        plus the EOS action.
""" - actions = [(d, None) for d in range(self.n_dim)] - actions += [(self.eos, 0.0)] + actions = [(d, 0) for d in range(self.n_dim)] + actions.append(self.eos) return actions def get_policy_output(self, params: dict): @@ -117,13 +127,13 @@ def get_policy_output(self, params: dict): with d in [0, ..., D] """ policy_output = np.ones(self.n_dim * self.n_params_per_dim + 1) - policy_output[1 :: self.n_params_per_dim] = params.vonmises_mean - policy_output[2 :: self.n_params_per_dim] = params.vonmises_concentration + policy_output[1 :: self.n_params_per_dim] = params["vonmises_mean"] + policy_output[2 :: self.n_params_per_dim] = params["vonmises_concentration"] return policy_output def get_mask_invalid_actions_forward(self, state=None, done=None): """ - Returns a vector with the length of the discrete part of the action space + 1: + Returns a vector with the length of the discrete part of the action space: True if action is invalid going forward given the current state, False otherwise. """ @@ -132,18 +142,18 @@ def get_mask_invalid_actions_forward(self, state=None, done=None): if done is None: done = self.done if done: - return [True for _ in range(len(self.action_space))] + return [True for _ in range(self.action_space_dim)] if state[-1] >= self.length_traj: - mask = [True for _ in range(len(self.action_space))] + mask = [True for _ in range(self.action_space_dim)] mask[-1] = False else: - mask = [False for _ in range(len(self.action_space))] + mask = [False for _ in range(self.action_space_dim)] mask[-1] = True return mask def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=None): """ - Returns a vector with the length of the discrete part of the action space + 1: + Returns a vector with the length of the discrete part of the action space: True if action is invalid going backward given the current state, False otherwise. 
""" @@ -152,15 +162,15 @@ def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=Non if done is None: done = self.done if done: - mask = [True for _ in range(len(self.action_space))] + mask = [True for _ in range(self.action_space_dim)] mask[-1] = False else: - mask = [False for _ in range(len(self.action_space))] + mask = [False for _ in range(self.action_space_dim)] mask[-1] = True # Catch cases where it would not be possible to reach the initial state noninit_states = [s for s, ss in zip(state[:-1], self.source_angles) if s != ss] if len(noninit_states) > state[-1]: - print("This point in the code should never be reached!") + raise ValueError("This point in the code should never be reached!") elif len(noninit_states) == state[-1] and len(noninit_states) >= state[-1] - 1: mask = [ True if s == ss else m @@ -168,22 +178,6 @@ def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=Non ] + [mask[-1]] return mask - def true_density(self): - # TODO - # Return pre-computed true density if already stored - if self._true_density is not None: - return self._true_density - # Calculate true density - all_x = self.get_all_terminating_states() - all_oracle = self.state2oracle(all_x) - rewards = self.oracle(all_oracle) - self._true_density = ( - rewards / rewards.sum(), - rewards, - list(map(tuple, all_x)), - ) - return self._true_density - def statebatch2proxy( self, states: List[List] ) -> TensorType["batch", "state_proxy_dim"]: @@ -254,9 +248,7 @@ def statebatch2policy(self, states: List[List]) -> npt.NDArray[np.float32]: ): step = states[:, -1] code_half_size = self.policy_encoding_dim_per_angle // 2 - int_coeff = np.tile( - np.arange(1, code_half_size + 1), states.shape[-1] - 1 - ) + int_coeff = np.tile(np.arange(1, code_half_size + 1), states.shape[-1] - 1) encoding = ( np.repeat(states[:, :-1], repeats=code_half_size, axis=1) * int_coeff ) @@ -291,21 +283,17 @@ def readable2state(self, readable: str) -> List: Converts a human-readable string representing a state into a state as a list of positions. Angles are converted back to radians. """ + # Preprocess + pattern = re.compile(r"\s+") + readable = re.sub(pattern, " ", readable) + readable = readable.replace(" ]", "]") + readable = readable.replace(" [", "[") + # Process pair = readable.split(" | ") angles = [np.float32(el) * np.pi / 180 for el in pair[0].strip("[]").split(" ")] n_actions = [int(pair[1])] return angles + n_actions - def reset(self, env_id=None): - """ - Resets the environment. 
- """ - self.state = self.source.copy() - self.n_actions = 0 - self.done = False - self.id = env_id - return self - def get_parents( self, state: List = None, done: bool = None, action: Tuple[int, float] = None ) -> Tuple[List[List], List[Tuple[int, float]]]: @@ -339,9 +327,13 @@ def get_parents( if done is None: done = self.done if done: - return [state], [(self.eos, 0.0)] + return [state], [self.eos] + # If source state + elif state[-1] == 0: + return [], [] else: - state[action[0]] = (state[action[0]] - action[1]) % (2 * np.pi) + dim, incr = action + state[dim] = (state[dim] - incr) % (2 * np.pi) state[-1] -= 1 parents = [state] return parents, [action] @@ -371,8 +363,8 @@ def sample_actions( dimensions = Categorical(logits=logits_dims).sample() logprobs_dim = self.logsoftmax(logits_dims)[ns_range, dimensions] # Sample angle increments - ns_range_noeos = ns_range[dimensions != self.eos] - dimensions_noeos = dimensions[dimensions != self.eos] + ns_range_noeos = ns_range[dimensions != self.eos[0]] + dimensions_noeos = dimensions[dimensions != self.eos[0]] angles = torch.zeros(n_states).to(device) logprobs_angles = torch.zeros(n_states).to(device) if len(dimensions_noeos) > 0: @@ -449,7 +441,7 @@ def get_logprobs( angledim_ne_source = torch.ne( states_target[ns_range, dimensions], source_aux[dimensions] ) - noeos = torch.ne(dimensions, self.eos) + noeos = torch.ne(dimensions, self.eos[0]) nofix_indices = torch.logical_and( torch.logical_or(nsource_ne_nsteps, angledim_ne_source) | is_forward, noeos ) @@ -507,6 +499,7 @@ def step( False, if the action is not allowed for the current state, e.g. stop at the root state """ + # If done, return invalid if self.done: return self.state, action, False # If only possible action is eos, then force eos @@ -514,17 +507,21 @@ def step( elif self.n_actions == self.length_traj: self.done = True self.n_actions += 1 - return self.state, (self.eos, 0.0), True - # If action is not eos, then perform action - elif action[0] != self.eos: + return self.state, self.eos, True + # If action is eos, then it is invalid + elif action == self.eos: + return self.state, action, False + # Otherwise perform action + else: + dim, incr = action self.n_actions += 1 - self.state[action[0]] += action[1] - self.state[action[0]] = self.state[action[0]] % (2 * np.pi) + self.state[dim] += incr + self.state[dim] = self.state[dim] % (2 * np.pi) self.state[-1] = self.n_actions return self.state, action, True - # If action is eos, then it is invalid - else: - return self.state, action, False + + def copy(self): + return deepcopy(self) def get_grid_terminating_states(self, n_states: int) -> List[List]: n_per_dim = int(np.ceil(n_states ** (1 / self.n_dim))) @@ -581,16 +578,26 @@ def fit_kde(self, samples, kernel="gaussian", bandwidth=0.1): return kde def plot_reward_samples( - self, samples, alpha=0.5, low=-np.pi * 0.5, high=2.5 * np.pi, dpi=150, limit_n_samples=500, **kwargs + self, + samples, + alpha=0.5, + low=-np.pi * 0.5, + high=2.5 * np.pi, + dpi=150, + limit_n_samples=500, + **kwargs, ): x = np.linspace(low, high, 201) y = np.linspace(low, high, 201) xx, yy = np.meshgrid(x, y) X = np.stack([xx, yy], axis=-1) - samples_mesh = torch.tensor( - X.reshape(-1, 2), dtype=self.float) - states_mesh = torch.cat([samples_mesh, torch.ones(samples_mesh.shape[0], 1)], 1).to(self.device) - rewards = torch2np(self.proxy2reward(self.proxy(self.statetorch2proxy(states_mesh)))) + samples_mesh = torch.tensor(X.reshape(-1, 2), dtype=self.float) + states_mesh = torch.cat( + [samples_mesh, 
torch.ones(samples_mesh.shape[0], 1)], 1 + ).to(self.device) + rewards = torch2np( + self.proxy2reward(self.proxy(self.statetorch2proxy(states_mesh))) + ) # Init figure fig, ax = plt.subplots() fig.set_dpi(dpi) @@ -608,17 +615,34 @@ def plot_reward_samples( for add_1 in [0, -2 * np.pi, 2 * np.pi]: if not (add_0 == add_1 == 0): extra_samples.append( - np.stack([samples[:limit_n_samples, 0] + add_0, samples[:limit_n_samples, 1] + add_1], axis=1) + np.stack( + [ + samples[:limit_n_samples, 0] + add_0, + samples[:limit_n_samples, 1] + add_1, + ], + axis=1, + ) ) extra_samples = np.concatenate(extra_samples) - ax.scatter(samples[:limit_n_samples, 0], samples[:limit_n_samples, 1], alpha=alpha) + ax.scatter( + samples[:limit_n_samples, 0], samples[:limit_n_samples, 1], alpha=alpha + ) ax.scatter(extra_samples[:, 0], extra_samples[:, 1], alpha=alpha, color="white") ax.grid() # Set tight layout plt.tight_layout() return fig - def plot_kde(self, kde, alpha=0.5, low=-np.pi * 0.5, high=2.5 * np.pi, dpi=150, colorbar=True, **kwargs): + def plot_kde( + self, + kde, + alpha=0.5, + low=-np.pi * 0.5, + high=2.5 * np.pi, + dpi=150, + colorbar=True, + **kwargs, + ): x = np.linspace(0, 2 * np.pi, 101) y = np.linspace(0, 2 * np.pi, 101) xx, yy = np.meshgrid(x, y) @@ -634,10 +658,10 @@ def plot_kde(self, kde, alpha=0.5, low=-np.pi * 0.5, high=2.5 * np.pi, dpi=150, fig.colorbar(h, ax=ax) ax.set_xticks([]) ax.set_yticks([]) - ax.text(0, -0.3, r'$0$', fontsize=15) - ax.text(-0.28, 0, r'$0$', fontsize=15) - ax.text(2*np.pi-0.4, -0.3, r'$2\pi$', fontsize=15) - ax.text(-0.45, 2*np.pi-0.3, r'$2\pi$', fontsize=15) + ax.text(0, -0.3, r"$0$", fontsize=15) + ax.text(-0.28, 0, r"$0$", fontsize=15) + ax.text(2 * np.pi - 0.4, -0.3, r"$2\pi$", fontsize=15) + ax.text(-0.45, 2 * np.pi - 0.3, r"$2\pi$", fontsize=15) for spine in ax.spines.values(): spine.set_visible(False) # Set tight layout diff --git a/gflownet/envs/plane.py b/gflownet/envs/plane.py index 4b7ea4eaa..f0c065db8 100644 --- a/gflownet/envs/plane.py +++ b/gflownet/envs/plane.py @@ -1,16 +1,18 @@ """ Classes to represent hyperplane environments """ -from typing import List, Tuple import itertools +from typing import List, Tuple + import numpy as np import numpy.typing as npt import pandas as pd import torch -from gflownet.envs.base import GFlowNetEnv -from torch.distributions import Categorical, Uniform, Beta +from torch.distributions import Beta, Categorical, Uniform from torchtyping import TensorType +from gflownet.envs.base import GFlowNetEnv + class Plane(GFlowNetEnv): """ @@ -71,13 +73,12 @@ def __init__( # Initialize angles and state attributes self.source = [0.0 for _ in range(self.n_dim)] self.reset() - self.action_space = self.get_actions_space() + self.action_space = self.get_action_space() self.fixed_policy_output = self.get_fixed_policy_output() self.policy_output_dim = len(self.fixed_policy_output) self.policy_input_dim = len(self.state2policy()) # Set up proxy - self.proxy.n_dim = self.n_dim - self.proxy.setup() + self.setup_proxy() # Oracle self.state2oracle = self.state2proxy self.statebatch2oracle = self.statebatch2proxy @@ -115,7 +116,7 @@ def reward_batch(self, states, done): reward[within_plane] = super().reward_batch(states_super, done_super) return reward - def get_actions_space(self): + def get_action_space(self): """ Constructs list with all possible actions. 
The actions are tuples with two values: (dimension, increment) where dimension indicates the index of the @@ -161,15 +162,15 @@ def get_mask_invalid_actions_forward(self, state=None, done=None): if done is None: done = self.done if done: - return [True for _ in range(len(self.action_space))] + return [True for _ in range(self.action_space_dim)] if ( any([s > self.max_val for s in self.state]) or self.n_actions >= self.max_traj_length ): - mask = [True for _ in range(len(self.action_space))] + mask = [True for _ in range(self.action_space_dim)] mask[-1] = False else: - mask = [False for _ in range(len(self.action_space))] + mask = [False for _ in range(self.action_space_dim)] return mask def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=None): @@ -183,29 +184,13 @@ def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=Non if done is None: done = self.done if done: - mask = [True for _ in range(len(self.action_space))] + mask = [True for _ in range(self.action_space_dim)] mask[-1] = False else: - mask = [False for _ in range(len(self.action_space))] + mask = [False for _ in range(self.action_space_dim)] # TODO: review: anything to do with max_value? return mask - def true_density(self): - # TODO - # Return pre-computed true density if already stored - if self._true_density is not None: - return self._true_density - # Calculate true density - all_x = self.get_all_terminating_states() - all_oracle = self.state2oracle(all_x) - rewards = self.oracle(all_oracle) - self._true_density = ( - rewards / rewards.sum(), - rewards, - list(map(tuple, all_x)), - ) - return self._true_density - def statebatch2proxy(self, states: List[List] = None) -> npt.NDArray[np.float32]: """ Scales the states into [0, max_val] @@ -428,44 +413,6 @@ def step( else: return self.state, action, False - def make_train_set(self, config): - """ - Constructs a randomly sampled train set. - - Args - ---- - """ - if config is None: - return None - elif "uniform" in config and "n" in config and config.uniform: - samples = self.get_grid_terminating_states(config.n) - energies = self.oracle(self.state2oracle(samples)) - else: - return None - df = pd.DataFrame( - {"samples": [self.state2readable(s) for s in samples], "energies": energies} - ) - return df - - def make_test_set(self, config): - """ - Constructs a test set. - - Args - ---- - """ - if config is None: - return None - elif "uniform" in config and "n" in config and config.uniform: - samples = self.get_grid_terminating_states(config.n) - energies = self.oracle(self.state2oracle(samples)) - else: - return None - df = pd.DataFrame( - {"samples": [self.state2readable(s) for s in samples], "energies": energies} - ) - return df - def get_grid_terminating_states(self, n_states: int) -> List[List]: n_per_dim = int(np.ceil(n_states ** (1 / self.n_dim))) linspaces = [np.linspace(0, self.max_val, n_per_dim) for _ in range(self.n_dim)] diff --git a/gflownet/envs/spacegroup.py b/gflownet/envs/spacegroup.py index 0a8fc2c66..3fcead99f 100644 --- a/gflownet/envs/spacegroup.py +++ b/gflownet/envs/spacegroup.py @@ -7,11 +7,12 @@ import numpy as np import torch from torch import Tensor +from torchtyping import TensorType from gflownet.envs.base import GFlowNetEnv from gflownet.utils.crystals.constants import ( - CRYSTAL_SYSTEMS, CRYSTAL_CLASSES, + CRYSTAL_SYSTEMS, POINT_SYMMETRIES, SPACE_GROUPS, ) @@ -22,12 +23,15 @@ class SpaceGroup(GFlowNetEnv): SpaceGroup environment for ionic conductivity. 
The state space is the combination of three properties: - 1. The crystal system (https://en.wikipedia.org/wiki/Crystal_system#Crystal_system) - (7 options + none) 2. The point symmetry - (https://en.wikipedia.org/wiki/Crystal_system#Crystal_classes) (5 options + none) + 1. The crystal system + See: https://en.wikipedia.org/wiki/Crystal_system#Crystal_system + (7 options + none) + 2. The point symmetry + See: https://en.wikipedia.org/wiki/Crystal_system#Crystal_classes + (5 options + none) 3. The space group - (https://en.wikipedia.org/wiki/Space_group#Table_of_space_groups_in_3_dimensions) - (230 options + none) + See: https://en.wikipedia.org/wiki/Space_group#Table_of_space_groups_in_3_dimensions + (230 options + none) The action space is the choice of property to update and the index within the property (e.g. crystal system 2, point symmetry 4, space group 69, etc.). The @@ -38,7 +42,6 @@ class SpaceGroup(GFlowNetEnv): """ def __init__(self, **kwargs): - super().__init__(**kwargs) self.crystal_systems = CRYSTAL_SYSTEMS self.crystal_classes = CRYSTAL_CLASSES self.point_symmetries = POINT_SYMMETRIES @@ -47,14 +50,15 @@ def __init__(self, **kwargs): self.n_crystal_classes = len(self.crystal_classes) self.n_point_symmetries = len(self.point_symmetries) self.n_space_groups = 230 - # A state is a list of [crystal system index, point symmetry index, space group] self.cs_idx, self.ps_idx, self.sg_idx = 0, 1, 2 + self.eos = (-1, -1) + # Source state: index 0 (empty) for all three properties (crystal system index, + # point symmetry index, space group) self.source = [0 for _ in range(3)] - self.eos = -1 - self.action_space = self.get_actions_space() - self.reset() + # Base class init + super().__init__(**kwargs) - def get_actions_space(self): + def get_action_space(self): """ Constructs list with all possible actions. An action is described by a tuple (property, index), where property is (0: crystal system, @@ -67,16 +71,18 @@ def get_actions_space(self): ): actions_prop = [(prop, idx + 1) for idx in range(n_idx)] actions += actions_prop - actions += [(self.eos, 0)] + actions += [self.eos] return actions - def get_max_traj_len(self): - return 3 - - def get_mask_invalid_actions_forward(self, state=None, done=None): + def get_mask_invalid_actions_forward( + self, + state: Optional[List] = None, + done: Optional[bool] = None, + ) -> List: """ - Returns a vector of length the action space + 1: True if forward action is - invalid given the current state, False otherwise. + Returns a list of length the action space with values: + - True if the forward action is invalid given the current state. + - False otherwise. 
""" if state is None: state = self.state.copy() @@ -87,7 +93,7 @@ def get_mask_invalid_actions_forward(self, state=None, done=None): # If space group has been selected, only valid action is EOS if state[self.sg_idx] != 0: mask = [True for _ in self.action_space] - mask[self.eos] = False + mask[-1] = False return mask # No constraints if neither crystal system nor point symmetry selected if state[self.cs_idx] == 0 and state[self.ps_idx] == 0: @@ -155,14 +161,50 @@ def state2oracle(self, state: List = None) -> Tensor: """ if state is None: state = self.state - if state[self.sg_idx] == 0: raise ValueError( "The space group must have been set in order to call the oracle" ) - return torch.Tensor(state[self.sg_idx], device=self.device, dtype=self.float) + def statebatch2oracle( + self, states: List[List] + ) -> TensorType["batch", "state_oracle_dim"]: + """ + Prepares a batch of states in "GFlowNet format" for the oracle. The input to the + oracle is simply the space group. + + Args + ---- + state : list + A state + + Returns + ---- + oracle_state : Tensor + """ + return self.statetorch2oracle( + torch.Tensor(states, device=self.device, dtype=self.float) + ) + + def statetorch2oracle( + self, states: TensorType["batch", "state_dim"] + ) -> TensorType["batch", "state_oracle_dim"]: + """ + Prepares a batch of states in "GFlowNet format" for the oracle. The input to the + oracle is simply the space group. + + Args + ---- + state : list + A state + + Returns + ---- + oracle_state : Tensor + """ + return torch.unsqueeze(states[:, self.sg_idx]) + def state2readable(self, state=None): """ Transforms the state, represented as a list of property indices, into a @@ -224,16 +266,6 @@ def readable2state(self, readable): state = [crystal_system, point_symmetry, space_group] return state - def reset(self, env_id=None): - """ - Resets the environment. - """ - self.state = self.source.copy() - self.n_actions = 0 - self.done = False - self.id = env_id - return self - def get_parents(self, state=None, done=None, action=None): """ Determines all parents and actions that lead to a state. @@ -261,7 +293,7 @@ def get_parents(self, state=None, done=None, action=None): if done is None: done = self.done if done: - return [state], [(self.eos, 0)] + return [state], [self.eos] else: parents = [] actions = [] @@ -297,7 +329,7 @@ def step(self, action: Tuple[int, int]) -> Tuple[List[int], Tuple[int, int], boo Args ---- action : tuple - Action to be executed. See: get_actions_space() + Action to be executed. 
See: get_action_space() Returns ------- @@ -321,25 +353,27 @@ def step(self, action: Tuple[int, int]) -> Tuple[List[int], Tuple[int, int], boo if self.get_mask_invalid_actions_forward()[action_idx]: return self.state, action, False valid = True + self.n_actions += 1 prop, idx = action # Action is not eos - if prop != self.eos: + if action != self.eos: state_next = self.state[:] state_next[prop] = idx # Set crystal system and point symmetry if space group is set - if state_next[self.sg_idx] != 0: - if state_next[self.cs_idx] == 0: - state_next[self.cs_idx] = self.space_groups[ - state_next[self.sg_idx] - ][2] - if state_next[self.ps_idx] == 0: - state_next[self.ps_idx] = self.space_groups[ - state_next[self.sg_idx] - ][3] - self.state = state_next - self.n_actions += 1 + self.state = self._set_constrained_properties(state_next) return self.state, action, valid # Action is eos else: self.done = True return self.state, action, valid + + def get_max_traj_length(self): + return 3 + + def _set_constrained_properties(self, state: List[int]) -> List[int]: + if state[self.sg_idx] != 0: + if state[self.cs_idx] == 0: + state[self.cs_idx] = self.space_groups[state[self.sg_idx]][2] + if state[self.ps_idx] == 0: + state[self.ps_idx] = self.space_groups[state[self.sg_idx]][3] + return state diff --git a/gflownet/envs/torus.py b/gflownet/envs/torus.py index e3aeb1d62..8ccd1594a 100644 --- a/gflownet/envs/torus.py +++ b/gflownet/envs/torus.py @@ -1,13 +1,15 @@ """ Classes to represent hyper-torus environments """ -from typing import List, Tuple import itertools +from typing import List, Optional, Tuple + import numpy as np import numpy.typing as npt import pandas as pd import torch from torchtyping import TensorType + from gflownet.envs.base import GFlowNetEnv @@ -36,99 +38,80 @@ class Torus(GFlowNetEnv): def __init__( self, - n_dim=2, - n_angles=3, - length_traj=1, - min_step_len=1, - max_step_len=1, + n_dim: int = 2, + n_angles: int = 3, + length_traj: int = 1, + max_increment: int = 1, + max_dim_per_action: int = 1, **kwargs, ): - super().__init__(**kwargs) + assert n_dim > 0 + assert n_angles > 1 + assert length_traj > 0 + assert max_increment > 0 + assert max_dim_per_action == -1 or max_dim_per_action > 0 self.n_dim = n_dim - self.eos = self.n_dim self.n_angles = n_angles self.length_traj = length_traj - # Initialize angles and state attributes - self.source_angles = [0.0 for _ in range(self.n_dim)] - # States are the concatenation of the angle state and number of actions + self.max_increment = max_increment + if max_dim_per_action == -1: + max_dim_per_action = self.n_dim + self.max_dim_per_action = max_dim_per_action + # Source state: position 0 at all dimensions and number of actions 0 + self.source_angles = [0 for _ in range(self.n_dim)] self.source = self.source_angles + [0] - self.reset() - self.source = self.angles.copy() - self.min_step_len = min_step_len - self.max_step_len = max_step_len - self.action_space = self.get_actions_space() - self.fixed_policy_output = self.get_fixed_policy_output() - self.random_policy_output = self.get_fixed_policy_output() - self.policy_output_dim = len(self.fixed_policy_output) - self.policy_input_dim = len(self.state2policy()) + # End-of-sequence action: (self.max_incremement + 1) in all dimensions + self.eos = tuple([self.max_increment + 1 for _ in range(self.n_dim)]) + # Angle increments in radians self.angle_rad = 2 * np.pi / self.n_angles - # Oracle + # TODO: assess if really needed self.state2oracle = self.state2proxy self.statebatch2oracle = 
self.statebatch2proxy - # Setup proxy - self.proxy.set_n_dim(self.n_dim) + # Base class init + super().__init__(**kwargs) - def get_actions_space(self): + def get_action_space(self): """ - Constructs list with all possible actions. The actions are tuples with two - values: (dimension, direction) where dimension indicates the index of the - dimension on which the action is to be performed and direction indicates to - increment or decrement with 1 or -1, respectively. The action "keep" is - indicated by (-1, 0). + Constructs list with all possible actions, including eos. An action is + represented by a vector of length n_dim where each index d indicates the + increment/decrement to apply to dimension d of the hyper-torus. A negative + value indicates a decrement. The action "keep" (no increment/decrement of any + dimensions) is valid and is indicated by all zeros. """ - valid_steplens = np.arange(self.min_step_len, self.max_step_len + 1) - dims = [a for a in range(self.n_dim)] - directions = [1, -1] + increments = [el for el in range(-self.max_increment, self.max_increment + 1)] actions = [] - for r in valid_steplens: - actions_r = [el for el in itertools.product(dims, directions, repeat=r)] - actions += actions_r - # Add "keep" action - actions = actions + [(-1, 0)] - # Add "eos" action - actions = actions + [(self.eos, 0)] + for action in itertools.product(increments, repeat=self.n_dim): + if len([el for el in action if el != 0]) <= self.max_dim_per_action: + actions.append(tuple(action)) + actions.append(self.eos) return actions - def get_mask_invalid_actions_forward(self, state=None, done=None): + def get_mask_invalid_actions_forward( + self, + state: Optional[List] = None, + done: Optional[bool] = None, + ) -> List: """ - Returns a vector of length the action space + 1: True if action is invalid - given the current state, False otherwise. + Returns a list of length the action space with values: + - True if the forward action is invalid from the current state. + - False otherwise. + All actions except EOS are valid if the maximum number of actions has not been + reached, and vice versa. """ if state is None: state = self.state.copy() if done is None: done = self.done if done: - return [True for _ in range(len(self.action_space))] + return [True for _ in range(self.action_space_dim)] if state[-1] >= self.length_traj: - mask = [True for _ in range(len(self.action_space))] + mask = [True for _ in range(self.action_space_dim)] mask[-1] = False else: - mask = [False for _ in range(len(self.action_space))] + mask = [False for _ in range(self.action_space_dim)] mask[-1] = True return mask - def true_density(self): - # Return pre-computed true density if already stored - if self._true_density is not None and self._log_z is not None: - return self._true_density, self._log_z - # Calculate true density - x = self.get_all_terminating_states() - rewards = self.reward_batch(x) - self._z = rewards.sum() - self._true_density = ( - rewards / self._z, - rewards, - list(map(tuple, x)), - ) - import ipdb - - ipdb.set_trace() - return self._true_density - - def fit_kde(x, kernel="exponential", bandwidth=0.1): - kde = KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(last_states.numpy()) - def statebatch2proxy(self, states: List[List]) -> npt.NDArray[np.float32]: """ Prepares a batch of states in "GFlowNet format" for the proxy: an array where @@ -145,7 +128,7 @@ def statetorch2proxy( """ return states[:, :-1] * self.angle_rad - # TODO: A circular encoding of the policy state would be better? 
+ # TODO: circular encoding as in htorus def state2policy(self, state=None) -> List: """ Transforms the angles part of the state given as argument (or self.state if @@ -253,17 +236,12 @@ def readable2state(self, readable: str) -> List: n_actions = [int(pair[1])] return angles + n_actions - def reset(self, env_id=None): - """ - Resets the environment. - """ - self.state = self.source.copy() - self.n_actions = 0 - self.done = False - self.id = env_id - return self - - def get_parents(self, state=None, done=None, action=None): + def get_parents( + self, + state: Optional[List] = None, + done: Optional[bool] = None, + action: Optional[Tuple] = None, + ) -> Tuple[List, List]: """ Determines all parents and actions that lead to state. @@ -288,52 +266,46 @@ def get_parents(self, state=None, done=None, action=None): List of actions that lead to state for each parent in parents """ - def _get_min_actions_to_source(source, ref): - def _get_min_actions_dim(u, v): - return np.min([np.abs(u - v), np.abs(u - (v - self.n_angles))]) - - return np.sum([_get_min_actions_dim(u, v) for u, v in zip(source, ref)]) - if state is None: state = self.state.copy() if done is None: done = self.done if done: - return [state], [(self.eos, 0)] + return [state], [self.eos] # If source state elif state[-1] == 0: return [], [] else: parents = [] actions = [] - for idx, (a_dim, a_dir) in enumerate(self.action_space[:-1]): + for idx, action in enumerate(self.action_space[:-1]): state_p = state.copy() angles_p = state_p[: self.n_dim] n_actions_p = state_p[-1] # Get parent n_actions_p -= 1 - if a_dim != -1: - angles_p[a_dim] -= a_dir + for dim, incr in enumerate(action): + angles_p[dim] -= incr # If negative angle index, restart from the back - if angles_p[a_dim] < 0: - angles_p[a_dim] = self.n_angles + angles_p[a_dim] + if angles_p[dim] < 0: + angles_p[dim] = self.n_angles + angles_p[dim] # If angle index larger than n_angles, restart from 0 - if angles_p[a_dim] >= self.n_angles: - angles_p[a_dim] = angles_p[a_dim] - self.n_angles - if _get_min_actions_to_source(self.source_angles, angles_p) < state[-1]: + if angles_p[dim] >= self.n_angles: + angles_p[dim] = angles_p[dim] - self.n_angles + if self._get_min_actions_to_source(angles_p) < state[-1]: state_p = angles_p + [n_actions_p] parents.append(state_p) - actions.append((a_dim, a_dir)) + actions.append(action) return parents, actions - def step(self, action: Tuple[int]) -> Tuple[List[int], Tuple[int, int], bool]: + def step(self, action: Tuple[int]) -> Tuple[List[int], Tuple[int], bool]: """ Executes step given an action. Args ---- action : tuple - Action to be executed. See: get_actions_space() + Action to be executed. See: get_action_space() Returns ------- @@ -346,76 +318,59 @@ def step(self, action: Tuple[int]) -> Tuple[List[int], Tuple[int, int], bool]: valid : bool False, if the action is not allowed for the current state. """ - assert action in self.action_space - a_dim, a_dir = action + # If done, return invalid if self.done: return self.state, action, False + # If action not found in action space raise an error + if action not in self.action_space: + raise ValueError( + f"Tried to execute action {action} not present in action space." 
+ ) + else: + action_idx = self.action_space.index(action) + # If action is in invalid mask, return invalid + if self.get_mask_invalid_actions_forward()[action_idx]: + return self.state, action, False # If only possible action is eos, then force eos # If the number of actions is equal to trajectory length elif self.n_actions == self.length_traj: self.done = True self.n_actions += 1 - return self.state, (self.eos, 0), True - # If action is not eos, then perform action - elif a_dim != self.eos: - angles_next = self.angles.copy() - # If action is not "keep" - if a_dim != -1: - angles_next[a_dim] += a_dir + return self.state, self.eos, True + # Perform non-EOS action + else: + angles_next = self.state.copy()[: self.n_dim] + for dim, incr in enumerate(action): + angles_next[dim] += incr # If negative angle index, restart from the back - if angles_next[a_dim] < 0: - angles_next[a_dim] = self.n_angles + angles_next[a_dim] + if angles_next[dim] < 0: + angles_next[dim] = self.n_angles + angles_next[dim] # If angle index larger than n_angles, restart from 0 - if angles_next[a_dim] >= self.n_angles: - angles_next[a_dim] = angles_next[a_dim] - self.n_angles - self.angles = angles_next + if angles_next[dim] >= self.n_angles: + angles_next[dim] = angles_next[dim] - self.n_angles self.n_actions += 1 - self.state = self.angles + [self.n_actions] + self.state = angles_next + [self.n_actions] valid = True return self.state, action, valid - # If action is eos, then it is invalid - else: - return self.state, (self.eos, 0), False - - def make_train_set(self, ntrain, oracle=None, seed=168, output_csv=None): - """ - Constructs a randomly sampled train set. - - Args - ---- - """ - rng = np.random.default_rng(seed) - angles = rng.integers(low=0, high=self.n_angles, size=(ntrain,) + (self.n_dim,)) - n_actions = self.length_traj * np.ones([ntrain, 1], dtype=np.int32) - samples = np.concatenate([angles, n_actions], axis=1) - if oracle: - energies = oracle(self.state2oracle(samples)) - else: - energies = self.oracle(self.state2oracle(samples)) - df_train = pd.DataFrame({"samples": list(samples), "energies": energies}) - if output_csv: - df_train.to_csv(output_csv) - return df_train - - def make_test_set(self, config): - """ - Constructs a test set. 
- - Args - ---- - """ - if "all" in config and config.all: - samples = self.get_all_terminating_states() - energies = self.oracle(self.state2oracle(samples)) - df_test = pd.DataFrame( - {"samples": [self.state2readable(s) for s in samples], "energies": energies} - ) - return df_test def get_all_terminating_states(self): - all_x = np.int32( - list(itertools.product(*[list(range(self.n_angles))] * self.n_dim)) - ) + all_x = itertools.product(*[list(range(self.n_angles))] * self.n_dim) + all_x_valid = [] + for x in all_x: + if self._get_min_actions_to_source(x) <= self.length_traj: + all_x_valid.append(x) + all_x = np.int32(all_x_valid) n_actions = self.length_traj * np.ones([all_x.shape[0], 1], dtype=np.int32) all_x = np.concatenate([all_x, n_actions], axis=1) return all_x.tolist() + + def fit_kde(x, kernel="exponential", bandwidth=0.1): + kde = KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(last_states.numpy()) + + def _get_min_actions_to_source(self, angles): + def _get_min_actions_dim(u, v): + return np.min([np.abs(u - v), np.abs(u - (v - self.n_angles))]) + + return np.sum( + [_get_min_actions_dim(u, v) for u, v in zip(self.source_angles, angles)] + ) diff --git a/gflownet/envs/torus_rounds.py b/gflownet/envs/torus_rounds.py deleted file mode 100644 index d718c553d..000000000 --- a/gflownet/envs/torus_rounds.py +++ /dev/null @@ -1,393 +0,0 @@ -""" -Classes to represent hyper-torus environments -""" -from typing import List -import itertools -import numpy as np -import pandas as pd -from gflownet.envs.base import GFlowNetEnv - - -class Torus(GFlowNetEnv): - """ - Hyper-torus environment - - Attributes - ---------- - ndim : int - Dimensionality of the torus - - n_angles : int - Number of angles into which each dimension is divided - - max_rounds : int - If larger than one, the action space allows for reaching the initial angle and - restart again up to max_rounds; and the state space contain the round number. - If zero, only one round is allowed, without reaching the initial angle. - """ - - def __init__( - self, - n_dim=2, - n_angles=3, - max_rounds=1, - min_step_len=1, - max_step_len=1, - env_id=None, - reward_beta=1, - reward_norm=1.0, - reward_norm_std_mult=0, - reward_func="boltzmann", - denorm_proxy=False, - energies_stats=None, - proxy=None, - oracle=None, - **kwargs, - ): - super(Torus, self).__init__( - env_id, - reward_beta, - reward_norm, - reward_norm_std_mult, - reward_func, - energies_stats, - denorm_proxy, - proxy, - oracle, - **kwargs, - ) - self.n_dim = n_dim - self.n_angles = n_angles - self.max_rounds = max_rounds - # TODO: do we need to one-hot encode the coordinates and rounds? 
- self.angles = [0 for _ in range(self.n_dim)] - self.rounds = [0 for _ in range(self.n_dim)] - # States are the concatenation of the angle state and the round state - self.state = self.angles + self.rounds - # TODO: A circular encoding of obs would be better - self.obs_dim = self.n_angles * self.n_dim * 2 - self.min_step_len = min_step_len - self.max_step_len = max_step_len - self.action_space = self.get_actions_space() - self.eos = len(self.action_space) - self.angle_rad = 2 * np.pi / self.n_angles - - def get_actions_space(self): - """ - Constructs list with all possible actions - """ - valid_steplens = np.arange(self.min_step_len, self.max_step_len + 1) - dims = [a for a in range(self.n_dim)] - actions = [] - for r in valid_steplens: - actions_r = [el for el in itertools.product(dims, repeat=r)] - actions += actions_r - return actions - - def get_mask_invalid_actions_forward(self, state=None, done=None): - """ - Returns a vector of length the action space + 1: True if action is invalid - given the current state, False otherwise. - """ - if state is None: - state = self.state.copy() - if done is None: - done = self.done - if done: - return [True for _ in range(len(self.action_space) + 1)] - mask = [False for _ in range(len(self.action_space) + 1)] - for idx, a in enumerate(self.action_space): - for d in a: - if ( - state[d] + 1 >= self.n_angles - and state[self.n_dim + d] + 1 >= self.max_rounds - ): - mask[idx] = True - break - return mask - - def true_density(self): - # Return pre-computed true density if already stored - if self._true_density is not None: - return self._true_density - # Calculate true density - all_angles = np.int32( - list(itertools.product(*[list(range(self.n_angles))] * self.n_dim)) - ) - all_oracle = self.state2oracle(all_angles) - rewards = self.oracle(all_oracle) - self._true_density = ( - rewards / rewards.sum(), - rewards, - list(map(tuple, all_angles)), - ) - return self._true_density - - def state2proxy(self, state_list): - """ - Prepares a list of states in "GFlowNet format" for the proxy: a list of length - n_dim with an angle in radians. - - Args - ---- - state_list : list of lists - List of states. - """ - # TODO: do we really need to convert back to list? - # TODO: split angles and round? - return (np.array(state_list) * self.angle_rad).tolist() - - def state2oracle(self, state_list): - """ - Prepares a list of states in "GFlowNet format" for the oracles: a list of length - n_dim with values in the range [cell_min, cell_max] for each state. - - Args - ---- - state_list : list of lists - List of states. - """ - return self.state2proxy(state_list) - - def state2policy(self, state=None): - """ - Transforms the state given as argument (or self.state if None) into a - one-hot encoding. The output is a list of len n_angles * n_dim, - where each n-th successive block of length elements is a one-hot encoding of - the position in the n-th dimension. - - Example, n_dim = 2, n_angles = 4: - - State, state: [0, 3, 1, 0] - | a | r | (a = angles, r = rounds) - - state2policy(state): [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0] - | 0 | 3 | 1 | 0 | - """ - if state is None: - state = self.state.copy() - # TODO: do we need float32? - # TODO: do we need one-hot? 
- obs = np.zeros(self.obs_dim, dtype=np.float32) - # Angles - obs[: self.n_dim * self.n_angles][ - (np.arange(self.n_dim) * self.n_angles + state[: self.n_dim]) - ] = 1 - # Rounds - obs[self.n_dim * self.n_angles :][ - (np.arange(self.n_dim) * self.n_angles + state[self.n_dim :]) - ] = 1 - return obs - - def obs2state(self, obs: List) -> List: - """ - Transforms the one-hot encoding version of a state given as argument - into a state (list of the position at each dimension). - - Example, n_dim = 2, n_angles = 4: - - obs: [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0] - | 0 | 3 | 1 | 0 | - - obs2state(obs): [0, 3, 1, 0] - | a | r | (a = angles, r = rounds) - """ - obs_mat_angles = np.reshape( - obs[: self.n_dim * self.n_angles], (self.n_dim, self.n_angles) - ) - obs_mat_rounds = np.reshape( - obs[self.n_dim * self.n_angles :], (self.n_dim, self.n_angles) - ) - angles = np.where(obs_mat_angles)[1] - rounds = np.where(obs_mat_rounds)[1] - # TODO: do we need to convert to list? - return np.concatenate([angles, rounds]).tolist() - - def state2readable(self, state, alphabet={}): - """ - Converts a state (a list of positions) into a human-readable string - representing a state. - """ - angles = ( - str(state[: self.n_dim]) - .replace("(", "[") - .replace(")", "]") - .replace(",", "") - ) - rounds = ( - str(state[self.n_dim :]) - .replace("(", "[") - .replace(")", "]") - .replace(",", "") - ) - return angles + " | " + rounds - - def readable2state(self, readable, alphabet={}): - """ - Converts a human-readable string representing a state into a state as a list of - positions. - """ - pair = readable.split(" | ") - angles = [int(el) for el in pair[0].strip("[]").split(" ")] - rounds = [int(el) for el in pair[1].strip("[]").split(" ")] - return angles + rounds - - def reset(self, env_id=None): - """ - Resets the environment. - """ - self.angles = [0 for _ in range(self.n_dim)] - self.rounds = [0 for _ in range(self.n_dim)] - self.state = self.angles + self.rounds - self.n_actions = 0 - self.done = False - self.id = env_id - return self - - def get_parents(self, state=None, done=None): - """ - Determines all parents and actions that lead to state. - - Args - ---- - state : list - Representation of a state, as a list of length n_angles where each element is - the position at each dimension. - - action : int - Last action performed - - Returns - ------- - parents : list - List of parents as state2policy(state) - - actions : list - List of actions that lead to state for each parent in parents - """ - # TODO - if state is None: - state = self.state.copy() - if done is None: - done = self.done - if done: - return [self.state2policy(state)], [self.eos] - else: - parents = [] - actions = [] - for idx, a in enumerate(self.action_space): - state_aux = state.copy() - angles_aux = state_aux[: self.n_dim] - rounds_aux = state_aux[self.n_dim :] - for a_sub in a: - if angles_aux[a_sub] == 0 and rounds_aux[a_sub] > 0: - angles_aux[a_sub] = self.n_angles - 1 - rounds_aux[a_sub] -= 1 - elif angles_aux[a_sub] > 0: - angles_aux[a_sub] -= 1 - else: - break - else: - state_aux = angles_aux + rounds_aux - parents.append(self.state2policy(state_aux)) - actions.append(idx) - return parents, actions - - def step(self, action_idx): - """ - Executes step given an action index. - - Args - ---- - action_idx : int - Index of action in the action space. 
a == eos indicates "stop action" - - Returns - ------- - self.state : list - The sequence after executing the action - - action_idx : int - Action index - - valid : bool - False, if the action is not allowed for the current state, e.g. stop at the - root state - """ - # If only possible action is eos, then force eos - # All dimensions are at the maximum angle and maximum round - if all([a == self.n_angles - 1 for a in self.angles]) and all( - [r == self.max_rounds - 1 for r in self.rounds] - ): - self.done = True - self.n_actions += 1 - return self.state, self.eos, True - # If action is not eos, then perform action - if action_idx != self.eos: - action = self.action_space[action_idx] - angles_next = self.angles.copy() - rounds_next = self.rounds.copy() - for a in action: - angles_next[a] += 1 - # Increment round and reset angle if necessary - if angles_next[a] == self.n_angles: - angles_next[a] = 0 - rounds_next[a] += 1 - if any([r >= self.max_rounds for r in rounds_next]): - valid = False - else: - self.angles = angles_next - self.rounds = rounds_next - self.state = self.angles + self.rounds - valid = True - self.n_actions += 1 - return self.state, action_idx, valid - # If action is eos, then perform eos - else: - self.done = True - self.n_actions += 1 - return self.state, self.eos, True - - def make_train_set(self, ntrain, oracle=None, seed=168, output_csv=None): - """ - Constructs a randomly sampled train set. - - Args - ---- - """ - rng = np.random.default_rng(seed) - angles = rng.integers(low=0, high=self.n_angles, size=(ntrain,) + (self.n_dim,)) - rounds = rng.integers( - low=0, high=self.max_rounds, size=(ntrain,) + (self.n_dim,) - ) - samples = np.concatenate([angles, rounds], axis=1) - if oracle: - energies = oracle(self.state2oracle(samples)) - else: - energies = self.oracle(self.state2oracle(samples)) - df_train = pd.DataFrame({"samples": list(samples), "energies": energies}) - if output_csv: - df_train.to_csv(output_csv) - return df_train - - def make_test_set(self, config): - """ - Constructs a test set. 
- - Args - ---- - """ - if "all" in config and config.all: - samples = self.get_all_terminating_states() - energies = self.oracle(self.state2oracle(samples)) - df_test = pd.DataFrame( - {"samples": [self.state2readable(s) for s in samples], "energies": energies} - ) - return df_test - - def get_all_terminating_states(self): - all_x = np.int32( - list( - itertools.product( - *[list(range(self.n_angles))] * self.n_dim - + [list(range(self.max_rounds))] * self.n_dim - ) - ) - ) - return all_x diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 3f9e13ad7..67b728604 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -3,12 +3,12 @@ TODO: - Seeds """ -import sys import copy +import pickle +import sys import time from collections import defaultdict from pathlib import Path -from omegaconf import OmegaConf from typing import List, Tuple import numpy as np @@ -16,12 +16,12 @@ import torch import torch.nn as nn import yaml -import pickle -from torch.distributions import Categorical, Bernoulli -from tqdm import tqdm +from omegaconf import OmegaConf from scipy.special import logsumexp +from torch.distributions import Bernoulli, Categorical +from tqdm import tqdm -from gflownet.envs.base import Buffer +from gflownet.utils.buffer import Buffer from gflownet.utils.common import set_device, set_float_precision, torch2np @@ -409,9 +409,9 @@ def _add_to_batch(batch, envs, actions, valids, train=True): for idx in range(n_empirical): env = envs[idx] env = env.set_state(x_tr[idx].tolist(), done=True) - env.n_actions = env.get_max_traj_len() + env.n_actions = env.get_max_traj_length() envs_offline.append(env) - actions.append((env.eos,)) + actions.append(env.eos) valids.append(True) else: envs_offline = [] @@ -526,7 +526,7 @@ def flowmatch_loss(self, it, batch, loginf=1000): masks_sf, ], ) - parents_a = parents_a.to(int).squeeze() + parents_a_idx = self.env.actions2indices(parents_a) # Compute rewards rewards = self.env.reward_torchbatch(states, done) assert torch.all(rewards[done] > 0) @@ -535,9 +535,9 @@ def flowmatch_loss(self, it, batch, loginf=1000): (states.shape[0], self.env.policy_output_dim), device=self.device, ) - inflow_logits[parents_batch_id, parents_a] = self.forward_policy( + inflow_logits[parents_batch_id, parents_a_idx] = self.forward_policy( self.env.statetorch2policy(parents) - )[torch.arange(parents.shape[0]), parents_a] + )[torch.arange(parents.shape[0]), parents_a_idx] inflow = torch.logsumexp(inflow_logits, dim=1) # Out-flows outflow_logits = self.forward_policy(self.env.statetorch2policy(states)) diff --git a/gflownet/oracle/molecule.py b/gflownet/oracle/molecule.py index efcb637b7..40d82365b 100644 --- a/gflownet/oracle/molecule.py +++ b/gflownet/oracle/molecule.py @@ -1,7 +1,6 @@ import numpy as np import numpy.typing as npt import torch - from xtb.interface import Calculator, Param, XTBException from xtb.libxtb import VERBOSITY_MUTED diff --git a/gflownet/proxy/aptamers.py b/gflownet/proxy/aptamers.py index 7ca019566..338c72347 100644 --- a/gflownet/proxy/aptamers.py +++ b/gflownet/proxy/aptamers.py @@ -1,7 +1,8 @@ -from gflownet.proxy.base import Proxy import numpy as np import numpy.typing as npt +from gflownet.proxy.base import Proxy + class Aptamers(Proxy): """ @@ -13,8 +14,8 @@ def __init__(self, oracle_id, norm): self.type = oracle_id self.norm = norm - def setup(self, max_seq_length, norm=True): - self.max_seq_length = max_seq_length + def setup(self, env=None): + self.max_seq_length = env.max_seq_length def __call__(self, states: 
npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: """ diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 3142e0f93..61580e80e 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -1,13 +1,15 @@ """ Base class of GFlowNet proxies """ -from abc import abstractmethod +from abc import ABC, abstractmethod + import numpy as np import numpy.typing as npt + from gflownet.utils.common import set_device, set_float_precision -class Proxy: +class Proxy(ABC): """ Generic proxy class """ @@ -20,6 +22,9 @@ def __init__(self, device, float_precision, higher_is_better=False, **kwargs): # Reward2Proxy multiplicative factor (1 or -1) self.higher_is_better = higher_is_better + def setup(self, env=None): + pass + @abstractmethod def __call__(self, states: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: """ diff --git a/gflownet/proxy/corners.py b/gflownet/proxy/corners.py index 9f617e05d..a577f647e 100644 --- a/gflownet/proxy/corners.py +++ b/gflownet/proxy/corners.py @@ -1,8 +1,9 @@ -from gflownet.proxy.base import Proxy import numpy as np import torch from torchtyping import TensorType +from gflownet.proxy.base import Proxy + class Corners(Proxy): """ @@ -15,7 +16,9 @@ def __init__(self, n_dim=None, mu=None, sigma=None, **kwargs): self.mu = mu self.sigma = sigma - def setup(self): + def setup(self, env=None): + if env: + self.n_dim = env.n_dim if self.sigma and self.mu and self.n_dim: self.mu_vec = self.mu * torch.ones( self.n_dim, device=self.device, dtype=self.float diff --git a/gflownet/proxy/molecule.py b/gflownet/proxy/molecule.py index dca87fd37..662eea0ff 100644 --- a/gflownet/proxy/molecule.py +++ b/gflownet/proxy/molecule.py @@ -10,7 +10,6 @@ from gflownet.proxy.base import Proxy from gflownet.utils.common import download_file_if_not_exists - TORCHANI_MODELS = { "ANI1x": torchani.models.ANI1x, "ANI1ccx": torchani.models.ANI1ccx, diff --git a/gflownet/proxy/torus.py b/gflownet/proxy/torus.py index 66003e0ec..4fa521835 100644 --- a/gflownet/proxy/torus.py +++ b/gflownet/proxy/torus.py @@ -1,7 +1,8 @@ -from gflownet.proxy.base import Proxy import torch from torchtyping import TensorType +from gflownet.proxy.base import Proxy + class Torus(Proxy): def __init__(self, normalize, alpha=1.0, beta=1.0, **kwargs): @@ -10,8 +11,9 @@ def __init__(self, normalize, alpha=1.0, beta=1.0, **kwargs): self.alpha = alpha self.beta = beta - def set_n_dim(self, n_dim): - self.n_dim = n_dim + def setup(self, env=None): + if env: + self.n_dim = env.n_dim @property def min(self): diff --git a/gflownet/proxy/uniform.py b/gflownet/proxy/uniform.py index d492072ce..436a4d6f3 100644 --- a/gflownet/proxy/uniform.py +++ b/gflownet/proxy/uniform.py @@ -1,11 +1,12 @@ -from gflownet.proxy.base import Proxy import torch from torchtyping import TensorType +from gflownet.proxy.base import Proxy + class Uniform(Proxy): def __init__(self, **kwargs): - super().__init__() + super().__init__(**kwargs) def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batch"]: return -1.0 * torch.ones(states.shape[0]).to(states) @@ -13,4 +14,3 @@ def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batc @property def min(self): return -1.0 - diff --git a/gflownet/utils/buffer.py b/gflownet/utils/buffer.py new file mode 100644 index 000000000..637361a2c --- /dev/null +++ b/gflownet/utils/buffer.py @@ -0,0 +1,210 @@ +""" +Buffer class to handle train and test data sets, reply buffer, etc. 
+""" +import pickle + +import numpy as np +import pandas as pd + + +class Buffer: + """ + Implements the functionality to manage various buffers of data: the records of + training samples, the train and test data sets, a replay buffer for training, etc. + """ + + def __init__( + self, + env, + make_train_test=False, + replay_capacity=0, + output_csv=None, + data_path=None, + train=None, + test=None, + logger=None, + **kwargs, + ): + self.logger = logger + self.env = env + self.replay_capacity = replay_capacity + self.main = pd.DataFrame(columns=["state", "traj", "reward", "energy", "iter"]) + self.replay = pd.DataFrame( + np.empty((self.replay_capacity, 5), dtype=object), + columns=["state", "traj", "reward", "energy", "iter"], + ) + self.replay.reward = pd.to_numeric(self.replay.reward) + self.replay.energy = pd.to_numeric(self.replay.energy) + self.replay.reward = [-1 for _ in range(self.replay_capacity)] + # Define train and test data sets + if train is not None and "type" in train: + self.train_type = train.type + else: + self.train_type = None + self.train, dict_tr = self.make_data_set(train) + if ( + self.train is not None + and "output_csv" in train + and train.output_csv is not None + ): + self.train.to_csv(train.output_csv) + if ( + dict_tr is not None + and "output_pkl" in train + and train.output_pkl is not None + ): + with open(train.output_pkl, "wb") as f: + pickle.dump(dict_tr, f) + self.train_pkl = train.output_pkl + else: + print( + """ + Important: offline trajectories will NOT be sampled. In order to sample + offline trajectories, the train configuration of the buffer should be + complete and feasible and an output pkl file should be defined in + env.buffer.train.output_pkl. + """ + ) + self.train_pkl = None + if test is not None and "type" in test: + self.test_type = test.type + else: + self.train_type = None + self.test, dict_tt = self.make_data_set(test) + if ( + self.test is not None + and "output_csv" in test + and test.output_csv is not None + ): + self.test.to_csv(test.output_csv) + if dict_tt is not None and "output_pkl" in test and test.output_pkl is not None: + with open(test.output_pkl, "wb") as f: + pickle.dump(dict_tt, f) + self.test_pkl = test.output_pkl + else: + print( + """ + Important: test metrics will NOT be computed. In order to compute + test metrics the test configuration of the buffer should be complete and + feasible and an output pkl file should be defined in + env.buffer.test.output_pkl. 
+ """ + ) + self.test_pkl = None + # Compute buffer statistics + if self.train is not None: + ( + self.mean_tr, + self.std_tr, + self.min_tr, + self.max_tr, + self.max_norm_tr, + ) = self.compute_stats(self.train) + if self.test is not None: + self.mean_tt, self.std_tt, self.min_tt, self.max_tt, _ = self.compute_stats( + self.test + ) + + def add( + self, + states, + trajs, + rewards, + energies, + it, + buffer="main", + criterion="greater", + ): + if buffer == "main": + self.main = pd.concat( + [ + self.main, + pd.DataFrame( + { + "state": [self.env.state2readable(s) for s in states], + "traj": [self.env.traj2readable(p) for p in trajs], + "reward": rewards, + "energy": energies, + "iter": it, + } + ), + ], + axis=0, + join="outer", + ) + elif buffer == "replay" and self.replay_capacity > 0: + if criterion == "greater": + self.replay = self._add_greater(states, trajs, rewards, energies, it) + + def _add_greater( + self, + states, + trajs, + rewards, + energies, + it, + ): + rewards_old = self.replay["reward"].values + rewards_new = rewards.copy() + while np.max(rewards_new) > np.min(rewards_old): + idx_new_max = np.argmax(rewards_new) + readable_state = self.env.state2readable(states[idx_new_max]) + if not self.replay["state"].isin([readable_state]).any(): + self.replay.iloc[self.replay.reward.argmin()] = { + "state": self.env.state2readable(states[idx_new_max]), + "traj": self.env.traj2readable(trajs[idx_new_max]), + "reward": rewards[idx_new_max], + "energy": energies[idx_new_max], + "iter": it, + } + rewards_old = self.replay["reward"].values + rewards_new[idx_new_max] = -1 + return self.replay + + def make_data_set(self, config): + """ + Constructs a data set as a DataFrame according to the configuration. + """ + if config is None: + return None, None + elif "path" in config and config.path is not None: + path = self.logger.logdir / Path("data") / config.path + df = pd.read_csv(path, index_col=0) + # TODO: check if state2readable transformation is required. 
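A minimal, self-contained usage sketch of Buffer.add() on the main buffer, using a hypothetical stub environment that provides only the two readable-conversion methods the buffer calls; with buffer="replay" and criterion="greater", the same call would instead go through _add_greater() above and keep only the highest-reward unique states seen so far.

import numpy as np

from gflownet.utils.buffer import Buffer


class StubEnv:
    def state2readable(self, state):
        return str(state)

    def traj2readable(self, traj):
        return " | ".join(str(a) for a in traj)


buf = Buffer(env=StubEnv(), replay_capacity=0, train=None, test=None)
buf.add(
    states=[[0, 1], [1, 1]],
    trajs=[[(1, 0), (0, 0)], [(0, 1), (0, 0)]],
    rewards=np.array([0.1, 0.5]),
    energies=np.array([-0.1, -0.5]),
    it=0,
    buffer="main",
)
print(buf.main)  # one row per state: readable state and trajectory, reward, energy, iteration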
+ return df + elif "type" not in config: + return None, None + elif config.type == "all" and hasattr(self.env, "get_all_terminating_states"): + samples = self.env.get_all_terminating_states() + elif ( + config.type == "grid" + and "n" in config + and hasattr(self.env, "get_grid_terminating_states") + ): + samples = self.env.get_grid_terminating_states(config.n) + elif ( + config.type == "uniform" + and "n" in config + and "seed" in config + and hasattr(self.env, "get_uniform_terminating_states") + ): + samples = self.env.get_uniform_terminating_states(config.n, config.seed) + else: + return None, None + energies = self.env.oracle(self.env.statebatch2oracle(samples)).tolist() + df = pd.DataFrame( + { + "samples": [self.env.state2readable(s) for s in samples], + "energies": energies, + } + ) + return df, {"x": samples, "energy": energies} + + def compute_stats(self, data): + mean_data = data["energies"].mean() + std_data = data["energies"].std() + min_data = data["energies"].min() + max_data = data["energies"].max() + data_zscores = (data["energies"] - mean_data) / std_data + max_norm_data = data_zscores.max() + return mean_data, std_data, min_data, max_data, max_norm_data diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index af7be0dd5..ae77b5e50 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -1,9 +1,10 @@ +from collections.abc import MutableMapping +from pathlib import Path + import numpy as np import torch - -from collections.abc import MutableMapping from hydra.utils import get_original_cwd -from pathlib import Path + def set_device(device: str): if device.lower() == "cuda" and torch.cuda.is_available(): @@ -45,12 +46,14 @@ def handle_logdir(): else: print(f"working directory not defined - Ending run...") + def download_file_if_not_exists(path: str, url: str): """ Download a file from google drive if path doestn't exist. 
url should be in the format: https://drive.google.com/uc?id=FILE_ID """ import gdown + path = Path(path) if not path.is_absolute(): # to avoid storing downloaded files with the logs, prefix is set to the original working dir diff --git a/gflownet/utils/crystals/constants.py b/gflownet/utils/crystals/constants.py index 087b98e6e..05ff5bfd6 100644 --- a/gflownet/utils/crystals/constants.py +++ b/gflownet/utils/crystals/constants.py @@ -246,13 +246,13 @@ # - crystal class indices in the crystal system # - point symmetry indices in the crystal system CRYSTAL_SYSTEMS = { - 1: ['triclinic', [1, 2], [1, 2]], - 2: ['monoclinic', [3, 4, 5], [1, 3, 2]], - 3: ['orthorhombic', [6, 7, 8], [4, 3, 2]], - 4: ['tetragonal', [9, 10, 11, 12, 13, 14, 15], [1, 5, 2, 4, 3]], - 5: ['trigonal', [16, 17, 18, 19, 20], [1, 2, 4, 3]], - 6: ['hexagonal', [21, 22, 23, 24, 25, 26, 27], [1, 5, 2, 4, 3]], - 7: ['cubic', [28, 29, 30, 31, 32], [4, 2, 5]], + 1: ["triclinic", [1, 2], [1, 2]], + 2: ["monoclinic", [3, 4, 5], [1, 3, 2]], + 3: ["orthorhombic", [6, 7, 8], [4, 3, 2]], + 4: ["tetragonal", [9, 10, 11, 12, 13, 14, 15], [1, 5, 2, 4, 3]], + 5: ["trigonal", [16, 17, 18, 19, 20], [1, 2, 4, 3]], + 6: ["hexagonal", [21, 22, 23, 24, 25, 26, 27], [1, 5, 2, 4, 3]], + 7: ["cubic", [28, 29, 30, 31, 32], [4, 2, 5]], } CRYSTAL_SYSTEMS_MINIMAL = { 1: "triclinic", @@ -591,11 +591,15 @@ # - crystal system indices with the point symmetry # - crystal class indices with the point symmetry POINT_SYMMETRIES = { - 1: ['enantiomorphic-polar', [1, 2, 4, 5, 6], [1, 3, 9, 16, 21]], - 2: ['centrosymmetric', [1, 2, 3, 4, 5, 6, 7], [2, 5, 8, 11, 15, 17, 20, 23, 27, 29, 32]], - 3: ['polar', [2, 3, 4, 5, 6], [4, 7, 13, 19, 25]], - 4: ['enantiomorphic', [3, 4, 5, 6, 7], [6, 12, 18, 24, 28, 30]], - 5: ['non-centrosymmetric', [4, 6, 7], [10, 14, 22, 26, 31]], + 1: ["enantiomorphic-polar", [1, 2, 4, 5, 6], [1, 3, 9, 16, 21]], + 2: [ + "centrosymmetric", + [1, 2, 3, 4, 5, 6, 7], + [2, 5, 8, 11, 15, 17, 20, 23, 27, 29, 32], + ], + 3: ["polar", [2, 3, 4, 5, 6], [4, 7, 13, 19, 25]], + 4: ["enantiomorphic", [3, 4, 5, 6, 7], [6, 12, 18, 24, 28, 30]], + 5: ["non-centrosymmetric", [4, 6, 7], [10, 14, 22, 26, 31]], } POINT_SYMMETRIES_MINIMAL = { 1: "enantiomorphic-polar", @@ -613,236 +617,236 @@ # - crystal system index # - point symmetry index SPACE_GROUPS = { - 1: ['1', 1, 1, 1], - 2: ['-1', 2, 1, 2], - 3: ['2', 3, 2, 1], - 4: ['2', 3, 2, 1], - 5: ['2', 3, 2, 1], - 6: ['m', 4, 2, 3], - 7: ['m', 4, 2, 3], - 8: ['m', 4, 2, 3], - 9: ['m', 4, 2, 3], - 10: ['2/m', 5, 2, 2], - 11: ['2/m', 5, 2, 2], - 12: ['2/m', 5, 2, 2], - 13: ['2/m', 5, 2, 2], - 14: ['2/m', 5, 2, 2], - 15: ['2/m', 5, 2, 2], - 16: ['222', 6, 3, 4], - 17: ['222', 6, 3, 4], - 18: ['222', 6, 3, 4], - 19: ['222', 6, 3, 4], - 20: ['222', 6, 3, 4], - 21: ['222', 6, 3, 4], - 22: ['222', 6, 3, 4], - 23: ['222', 6, 3, 4], - 24: ['222', 6, 3, 4], - 25: ['mm2', 7, 3, 3], - 26: ['mm2', 7, 3, 3], - 27: ['mm2', 7, 3, 3], - 28: ['mm2', 7, 3, 3], - 29: ['mm2', 7, 3, 3], - 30: ['mm2', 7, 3, 3], - 31: ['mm2', 7, 3, 3], - 32: ['mm2', 7, 3, 3], - 33: ['mm2', 7, 3, 3], - 34: ['mm2', 7, 3, 3], - 35: ['mm2', 7, 3, 3], - 36: ['mm2', 7, 3, 3], - 37: ['mm2', 7, 3, 3], - 38: ['mm2', 7, 3, 3], - 39: ['mm2', 7, 3, 3], - 40: ['mm2', 7, 3, 3], - 41: ['mm2', 7, 3, 3], - 42: ['mm2', 7, 3, 3], - 43: ['mm2', 7, 3, 3], - 44: ['mm2', 7, 3, 3], - 45: ['mm2', 7, 3, 3], - 46: ['mm2', 7, 3, 3], - 47: ['mmm', 8, 3, 2], - 48: ['mmm', 8, 3, 2], - 49: ['mmm', 8, 3, 2], - 50: ['mmm', 8, 3, 2], - 51: ['mmm', 8, 3, 2], - 52: ['mmm', 8, 3, 2], - 
53: ['mmm', 8, 3, 2], - 54: ['mmm', 8, 3, 2], - 55: ['mmm', 8, 3, 2], - 56: ['mmm', 8, 3, 2], - 57: ['mmm', 8, 3, 2], - 58: ['mmm', 8, 3, 2], - 59: ['mmm', 8, 3, 2], - 60: ['mmm', 8, 3, 2], - 61: ['mmm', 8, 3, 2], - 62: ['mmm', 8, 3, 2], - 63: ['mmm', 8, 3, 2], - 64: ['mmm', 8, 3, 2], - 65: ['mmm', 8, 3, 2], - 66: ['mmm', 8, 3, 2], - 67: ['mmm', 8, 3, 2], - 68: ['mmm', 8, 3, 2], - 69: ['mmm', 8, 3, 2], - 70: ['mmm', 8, 3, 2], - 71: ['mmm', 8, 3, 2], - 72: ['mmm', 8, 3, 2], - 73: ['mmm', 8, 3, 2], - 74: ['mmm', 8, 3, 2], - 75: ['4', 9, 4, 1], - 76: ['4', 9, 4, 1], - 77: ['4', 9, 4, 1], - 78: ['4', 9, 4, 1], - 79: ['4', 9, 4, 1], - 80: ['4', 9, 4, 1], - 81: ['-4', 10, 4, 5], - 82: ['-4', 10, 4, 5], - 83: ['4/m', 11, 4, 2], - 84: ['4/m', 11, 4, 2], - 85: ['4/m', 11, 4, 2], - 86: ['4/m', 11, 4, 2], - 87: ['4/m', 11, 4, 2], - 88: ['4/m', 11, 4, 2], - 89: ['422', 12, 4, 4], - 90: ['422', 12, 4, 4], - 91: ['422', 12, 4, 4], - 92: ['422', 12, 4, 4], - 93: ['422', 12, 4, 4], - 94: ['422', 12, 4, 4], - 95: ['422', 12, 4, 4], - 96: ['422', 12, 4, 4], - 97: ['422', 12, 4, 4], - 98: ['422', 12, 4, 4], - 99: ['4mm', 13, 4, 3], - 100: ['4mm', 13, 4, 3], - 101: ['4mm', 13, 4, 3], - 102: ['4mm', 13, 4, 3], - 103: ['4mm', 13, 4, 3], - 104: ['4mm', 13, 4, 3], - 105: ['4mm', 13, 4, 3], - 106: ['4mm', 13, 4, 3], - 107: ['4mm', 13, 4, 3], - 108: ['4mm', 13, 4, 3], - 109: ['4mm', 13, 4, 3], - 110: ['4mm', 13, 4, 3], - 111: ['-42m', 14, 4, 5], - 112: ['-42m', 14, 4, 5], - 113: ['-42m', 14, 4, 5], - 114: ['-42m', 14, 4, 5], - 115: ['-4m2', 14, 4, 5], - 116: ['-4m2', 14, 4, 5], - 117: ['-4m2', 14, 4, 5], - 118: ['-4m2', 14, 4, 5], - 119: ['-4m2', 14, 4, 5], - 120: ['-4m2', 14, 4, 5], - 121: ['-42m', 14, 4, 5], - 122: ['-42m', 14, 4, 5], - 123: ['4/mmm', 15, 4, 2], - 124: ['4/mmm', 15, 4, 2], - 125: ['4/mmm', 15, 4, 2], - 126: ['4/mmm', 15, 4, 2], - 127: ['4/mmm', 15, 4, 2], - 128: ['4/mmm', 15, 4, 2], - 129: ['4/mmm', 15, 4, 2], - 130: ['4/mmm', 15, 4, 2], - 131: ['4/mmm', 15, 4, 2], - 132: ['4/mmm', 15, 4, 2], - 133: ['4/mmm', 15, 4, 2], - 134: ['4/mmm', 15, 4, 2], - 135: ['4/mmm', 15, 4, 2], - 136: ['4/mmm', 15, 4, 2], - 137: ['4/mmm', 15, 4, 2], - 138: ['4/mmm', 15, 4, 2], - 139: ['4/mmm', 15, 4, 2], - 140: ['4/mmm', 15, 4, 2], - 141: ['4/mmm', 15, 4, 2], - 142: ['4/mmm', 15, 4, 2], - 143: ['3', 16, 5, 1], - 144: ['3', 16, 5, 1], - 145: ['3', 16, 5, 1], - 146: ['3', 16, 5, 1], - 147: ['-3', 17, 5, 2], - 148: ['-3', 17, 5, 2], - 149: ['312', 18, 5, 4], - 150: ['321', 18, 5, 4], - 151: ['312', 18, 5, 4], - 152: ['321', 18, 5, 4], - 153: ['312', 18, 5, 4], - 154: ['321', 18, 5, 4], - 155: ['32', 18, 5, 4], - 156: ['3m1', 19, 5, 3], - 157: ['31m', 19, 5, 3], - 158: ['3m1', 19, 5, 3], - 159: ['31m', 19, 5, 3], - 160: ['3m', 19, 5, 3], - 161: ['3m', 19, 5, 3], - 162: ['-31m', 20, 5, 2], - 163: ['-31m', 20, 5, 2], - 164: ['-3m1', 20, 5, 2], - 165: ['-3m1', 20, 5, 2], - 166: ['-3m', 20, 5, 2], - 167: ['-3m', 20, 5, 2], - 168: ['6', 21, 6, 1], - 169: ['6', 21, 6, 1], - 170: ['6', 21, 6, 1], - 171: ['6', 21, 6, 1], - 172: ['6', 21, 6, 1], - 173: ['6', 21, 6, 1], - 174: ['-6', 22, 6, 5], - 175: ['6/m', 23, 6, 2], - 176: ['6/m', 23, 6, 2], - 177: ['622', 24, 6, 4], - 178: ['622', 24, 6, 4], - 179: ['622', 24, 6, 4], - 180: ['622', 24, 6, 4], - 181: ['622', 24, 6, 4], - 182: ['622', 24, 6, 4], - 183: ['6mm', 25, 6, 3], - 184: ['6mm', 25, 6, 3], - 185: ['6mm', 25, 6, 3], - 186: ['6mm', 25, 6, 3], - 187: ['-6m2', 26, 6, 5], - 188: ['-6m2', 26, 6, 5], - 189: ['-62m', 26, 6, 5], - 190: ['-62m', 26, 6, 5], - 191: ['6/mmm', 27, 6, 
2], - 192: ['6/mmm', 27, 6, 2], - 193: ['6/mmm', 27, 6, 2], - 194: ['6/mmm', 27, 6, 2], - 195: ['23', 28, 7, 4], - 196: ['23', 28, 7, 4], - 197: ['23', 28, 7, 4], - 198: ['23', 28, 7, 4], - 199: ['23', 28, 7, 4], - 200: ['m-3', 29, 7, 2], - 201: ['m-3', 29, 7, 2], - 202: ['m-3', 29, 7, 2], - 203: ['m-3', 29, 7, 2], - 204: ['m-3', 29, 7, 2], - 205: ['m-3', 29, 7, 2], - 206: ['m-3', 29, 7, 2], - 207: ['432', 30, 7, 4], - 208: ['432', 30, 7, 4], - 209: ['432', 30, 7, 4], - 210: ['432', 30, 7, 4], - 211: ['432', 30, 7, 4], - 212: ['432', 30, 7, 4], - 213: ['432', 30, 7, 4], - 214: ['432', 30, 7, 4], - 215: ['-43m', 31, 7, 5], - 216: ['-43m', 31, 7, 5], - 217: ['-43m', 31, 7, 5], - 218: ['-43m', 31, 7, 5], - 219: ['-43m', 31, 7, 5], - 220: ['-43m', 31, 7, 5], - 221: ['m-3m', 32, 7, 2], - 222: ['m-3m', 32, 7, 2], - 223: ['m-3m', 32, 7, 2], - 224: ['m-3m', 32, 7, 2], - 225: ['m-3m', 32, 7, 2], - 226: ['m-3m', 32, 7, 2], - 227: ['m-3m', 32, 7, 2], - 228: ['m-3m', 32, 7, 2], - 229: ['m-3m', 32, 7, 2], - 230: ['m-3m', 32, 7, 2], + 1: ["1", 1, 1, 1], + 2: ["-1", 2, 1, 2], + 3: ["2", 3, 2, 1], + 4: ["2", 3, 2, 1], + 5: ["2", 3, 2, 1], + 6: ["m", 4, 2, 3], + 7: ["m", 4, 2, 3], + 8: ["m", 4, 2, 3], + 9: ["m", 4, 2, 3], + 10: ["2/m", 5, 2, 2], + 11: ["2/m", 5, 2, 2], + 12: ["2/m", 5, 2, 2], + 13: ["2/m", 5, 2, 2], + 14: ["2/m", 5, 2, 2], + 15: ["2/m", 5, 2, 2], + 16: ["222", 6, 3, 4], + 17: ["222", 6, 3, 4], + 18: ["222", 6, 3, 4], + 19: ["222", 6, 3, 4], + 20: ["222", 6, 3, 4], + 21: ["222", 6, 3, 4], + 22: ["222", 6, 3, 4], + 23: ["222", 6, 3, 4], + 24: ["222", 6, 3, 4], + 25: ["mm2", 7, 3, 3], + 26: ["mm2", 7, 3, 3], + 27: ["mm2", 7, 3, 3], + 28: ["mm2", 7, 3, 3], + 29: ["mm2", 7, 3, 3], + 30: ["mm2", 7, 3, 3], + 31: ["mm2", 7, 3, 3], + 32: ["mm2", 7, 3, 3], + 33: ["mm2", 7, 3, 3], + 34: ["mm2", 7, 3, 3], + 35: ["mm2", 7, 3, 3], + 36: ["mm2", 7, 3, 3], + 37: ["mm2", 7, 3, 3], + 38: ["mm2", 7, 3, 3], + 39: ["mm2", 7, 3, 3], + 40: ["mm2", 7, 3, 3], + 41: ["mm2", 7, 3, 3], + 42: ["mm2", 7, 3, 3], + 43: ["mm2", 7, 3, 3], + 44: ["mm2", 7, 3, 3], + 45: ["mm2", 7, 3, 3], + 46: ["mm2", 7, 3, 3], + 47: ["mmm", 8, 3, 2], + 48: ["mmm", 8, 3, 2], + 49: ["mmm", 8, 3, 2], + 50: ["mmm", 8, 3, 2], + 51: ["mmm", 8, 3, 2], + 52: ["mmm", 8, 3, 2], + 53: ["mmm", 8, 3, 2], + 54: ["mmm", 8, 3, 2], + 55: ["mmm", 8, 3, 2], + 56: ["mmm", 8, 3, 2], + 57: ["mmm", 8, 3, 2], + 58: ["mmm", 8, 3, 2], + 59: ["mmm", 8, 3, 2], + 60: ["mmm", 8, 3, 2], + 61: ["mmm", 8, 3, 2], + 62: ["mmm", 8, 3, 2], + 63: ["mmm", 8, 3, 2], + 64: ["mmm", 8, 3, 2], + 65: ["mmm", 8, 3, 2], + 66: ["mmm", 8, 3, 2], + 67: ["mmm", 8, 3, 2], + 68: ["mmm", 8, 3, 2], + 69: ["mmm", 8, 3, 2], + 70: ["mmm", 8, 3, 2], + 71: ["mmm", 8, 3, 2], + 72: ["mmm", 8, 3, 2], + 73: ["mmm", 8, 3, 2], + 74: ["mmm", 8, 3, 2], + 75: ["4", 9, 4, 1], + 76: ["4", 9, 4, 1], + 77: ["4", 9, 4, 1], + 78: ["4", 9, 4, 1], + 79: ["4", 9, 4, 1], + 80: ["4", 9, 4, 1], + 81: ["-4", 10, 4, 5], + 82: ["-4", 10, 4, 5], + 83: ["4/m", 11, 4, 2], + 84: ["4/m", 11, 4, 2], + 85: ["4/m", 11, 4, 2], + 86: ["4/m", 11, 4, 2], + 87: ["4/m", 11, 4, 2], + 88: ["4/m", 11, 4, 2], + 89: ["422", 12, 4, 4], + 90: ["422", 12, 4, 4], + 91: ["422", 12, 4, 4], + 92: ["422", 12, 4, 4], + 93: ["422", 12, 4, 4], + 94: ["422", 12, 4, 4], + 95: ["422", 12, 4, 4], + 96: ["422", 12, 4, 4], + 97: ["422", 12, 4, 4], + 98: ["422", 12, 4, 4], + 99: ["4mm", 13, 4, 3], + 100: ["4mm", 13, 4, 3], + 101: ["4mm", 13, 4, 3], + 102: ["4mm", 13, 4, 3], + 103: ["4mm", 13, 4, 3], + 104: ["4mm", 13, 4, 3], + 105: ["4mm", 13, 4, 3], + 106: 
["4mm", 13, 4, 3], + 107: ["4mm", 13, 4, 3], + 108: ["4mm", 13, 4, 3], + 109: ["4mm", 13, 4, 3], + 110: ["4mm", 13, 4, 3], + 111: ["-42m", 14, 4, 5], + 112: ["-42m", 14, 4, 5], + 113: ["-42m", 14, 4, 5], + 114: ["-42m", 14, 4, 5], + 115: ["-4m2", 14, 4, 5], + 116: ["-4m2", 14, 4, 5], + 117: ["-4m2", 14, 4, 5], + 118: ["-4m2", 14, 4, 5], + 119: ["-4m2", 14, 4, 5], + 120: ["-4m2", 14, 4, 5], + 121: ["-42m", 14, 4, 5], + 122: ["-42m", 14, 4, 5], + 123: ["4/mmm", 15, 4, 2], + 124: ["4/mmm", 15, 4, 2], + 125: ["4/mmm", 15, 4, 2], + 126: ["4/mmm", 15, 4, 2], + 127: ["4/mmm", 15, 4, 2], + 128: ["4/mmm", 15, 4, 2], + 129: ["4/mmm", 15, 4, 2], + 130: ["4/mmm", 15, 4, 2], + 131: ["4/mmm", 15, 4, 2], + 132: ["4/mmm", 15, 4, 2], + 133: ["4/mmm", 15, 4, 2], + 134: ["4/mmm", 15, 4, 2], + 135: ["4/mmm", 15, 4, 2], + 136: ["4/mmm", 15, 4, 2], + 137: ["4/mmm", 15, 4, 2], + 138: ["4/mmm", 15, 4, 2], + 139: ["4/mmm", 15, 4, 2], + 140: ["4/mmm", 15, 4, 2], + 141: ["4/mmm", 15, 4, 2], + 142: ["4/mmm", 15, 4, 2], + 143: ["3", 16, 5, 1], + 144: ["3", 16, 5, 1], + 145: ["3", 16, 5, 1], + 146: ["3", 16, 5, 1], + 147: ["-3", 17, 5, 2], + 148: ["-3", 17, 5, 2], + 149: ["312", 18, 5, 4], + 150: ["321", 18, 5, 4], + 151: ["312", 18, 5, 4], + 152: ["321", 18, 5, 4], + 153: ["312", 18, 5, 4], + 154: ["321", 18, 5, 4], + 155: ["32", 18, 5, 4], + 156: ["3m1", 19, 5, 3], + 157: ["31m", 19, 5, 3], + 158: ["3m1", 19, 5, 3], + 159: ["31m", 19, 5, 3], + 160: ["3m", 19, 5, 3], + 161: ["3m", 19, 5, 3], + 162: ["-31m", 20, 5, 2], + 163: ["-31m", 20, 5, 2], + 164: ["-3m1", 20, 5, 2], + 165: ["-3m1", 20, 5, 2], + 166: ["-3m", 20, 5, 2], + 167: ["-3m", 20, 5, 2], + 168: ["6", 21, 6, 1], + 169: ["6", 21, 6, 1], + 170: ["6", 21, 6, 1], + 171: ["6", 21, 6, 1], + 172: ["6", 21, 6, 1], + 173: ["6", 21, 6, 1], + 174: ["-6", 22, 6, 5], + 175: ["6/m", 23, 6, 2], + 176: ["6/m", 23, 6, 2], + 177: ["622", 24, 6, 4], + 178: ["622", 24, 6, 4], + 179: ["622", 24, 6, 4], + 180: ["622", 24, 6, 4], + 181: ["622", 24, 6, 4], + 182: ["622", 24, 6, 4], + 183: ["6mm", 25, 6, 3], + 184: ["6mm", 25, 6, 3], + 185: ["6mm", 25, 6, 3], + 186: ["6mm", 25, 6, 3], + 187: ["-6m2", 26, 6, 5], + 188: ["-6m2", 26, 6, 5], + 189: ["-62m", 26, 6, 5], + 190: ["-62m", 26, 6, 5], + 191: ["6/mmm", 27, 6, 2], + 192: ["6/mmm", 27, 6, 2], + 193: ["6/mmm", 27, 6, 2], + 194: ["6/mmm", 27, 6, 2], + 195: ["23", 28, 7, 4], + 196: ["23", 28, 7, 4], + 197: ["23", 28, 7, 4], + 198: ["23", 28, 7, 4], + 199: ["23", 28, 7, 4], + 200: ["m-3", 29, 7, 2], + 201: ["m-3", 29, 7, 2], + 202: ["m-3", 29, 7, 2], + 203: ["m-3", 29, 7, 2], + 204: ["m-3", 29, 7, 2], + 205: ["m-3", 29, 7, 2], + 206: ["m-3", 29, 7, 2], + 207: ["432", 30, 7, 4], + 208: ["432", 30, 7, 4], + 209: ["432", 30, 7, 4], + 210: ["432", 30, 7, 4], + 211: ["432", 30, 7, 4], + 212: ["432", 30, 7, 4], + 213: ["432", 30, 7, 4], + 214: ["432", 30, 7, 4], + 215: ["-43m", 31, 7, 5], + 216: ["-43m", 31, 7, 5], + 217: ["-43m", 31, 7, 5], + 218: ["-43m", 31, 7, 5], + 219: ["-43m", 31, 7, 5], + 220: ["-43m", 31, 7, 5], + 221: ["m-3m", 32, 7, 2], + 222: ["m-3m", 32, 7, 2], + 223: ["m-3m", 32, 7, 2], + 224: ["m-3m", 32, 7, 2], + 225: ["m-3m", 32, 7, 2], + 226: ["m-3m", 32, 7, 2], + 227: ["m-3m", 32, 7, 2], + 228: ["m-3m", 32, 7, 2], + 229: ["m-3m", 32, 7, 2], + 230: ["m-3m", 32, 7, 2], } SPACE_GROUPS_MINIMAL = { 1: "1", diff --git a/gflownet/utils/legacy.py b/gflownet/utils/legacy.py index f923eb642..f015a30e9 100644 --- a/gflownet/utils/legacy.py +++ b/gflownet/utils/legacy.py @@ -1,12 +1,12 @@ """import statement""" -from argparse import 
Namespace -import yaml -from pathlib import Path -import numpy as np -import matplotlib.pyplot as plt import os import time +from argparse import Namespace +from pathlib import Path +import matplotlib.pyplot as plt +import numpy as np +import yaml """ This is a general utilities file for the active learning pipeline diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py index fc82852c8..1729a1e36 100644 --- a/gflownet/utils/logger.py +++ b/gflownet/utils/logger.py @@ -1,10 +1,11 @@ from datetime import datetime +from pathlib import Path + +import matplotlib.pyplot as plt import numpy as np import torch -from pathlib import Path from numpy import array from omegaconf import OmegaConf -import matplotlib.pyplot as plt class Logger: @@ -188,7 +189,7 @@ def log_train( states_term: list, batch_size: int, logz, - learning_rates: list, # [lr, lr_logZ] + learning_rates: list, # [lr, lr_logZ] step: int, use_context: bool, ): @@ -212,7 +213,7 @@ def log_train( "batch_size", "logZ", "lr", - "lr_logZ" + "lr_logZ", ], [ np.mean(rewards), @@ -224,7 +225,7 @@ def log_train( batch_size, logz, learning_rates[0], - learning_rates[1] + learning_rates[1], ], ) ) diff --git a/gflownet/utils/metrics.py b/gflownet/utils/metrics.py index b9a9ef87c..99a6dad62 100644 --- a/gflownet/utils/metrics.py +++ b/gflownet/utils/metrics.py @@ -1,6 +1,5 @@ -import torch import numpy as np - +import torch from sklearn.neighbors import KernelDensity diff --git a/gflownet/utils/molecule/atom_positions_dataset.py b/gflownet/utils/molecule/atom_positions_dataset.py index 0ed1eb171..0b66f4363 100644 --- a/gflownet/utils/molecule/atom_positions_dataset.py +++ b/gflownet/utils/molecule/atom_positions_dataset.py @@ -2,6 +2,7 @@ from gflownet.utils.common import download_file_if_not_exists + class AtomPositionsDataset: def __init__(self, path_to_data, url_to_data): path_to_data = download_file_if_not_exists(path_to_data, url_to_data) diff --git a/gflownet/utils/molecule/conformer.py b/gflownet/utils/molecule/conformer.py index 1fffc9fe8..689c73db5 100644 --- a/gflownet/utils/molecule/conformer.py +++ b/gflownet/utils/molecule/conformer.py @@ -1,17 +1,15 @@ -import numpy as np -import torch - from collections import defaultdict from copy import deepcopy + +import numpy as np +import torch from rdkit import Chem -from rdkit.Chem import AllChem -from rdkit.Chem import rdMolTransforms -from rdkit.Chem import TorsionFingerprints +from rdkit.Chem import AllChem, TorsionFingerprints, rdMolTransforms from rdkit.Geometry.rdGeometry import Point3D from gflownet.utils.molecule import constants -from gflownet.utils.molecule.featurizer import MolDGLFeaturizer from gflownet.utils.molecule.conformer_base import ConformerBase +from gflownet.utils.molecule.featurizer import MolDGLFeaturizer class Conformer(ConformerBase): @@ -35,7 +33,8 @@ def dgl_graph(self): def set_atom_positions_dgl(self, atom_positions): """Set atom positions of the self.dgl_graph to the input atom_positions values - :param atom_positions: 2d numpy array of shape [num atoms, 3] with new atom positions""" + :param atom_positions: 2d numpy array of shape [num atoms, 3] with new atom positions + """ self._dgl_graph.ndata[constants.atom_position_name] = torch.Tensor( atom_positions ) @@ -73,6 +72,7 @@ def get_ta_index_in_dgl_graph(self, torsion_angle): if __name__ == "__main__": from tabulate import tabulate + from gflownet.utils.molecule.conformer_base import get_all_torsion_angles rmol = Chem.MolFromSmiles(constants.ad_smiles) diff --git 
a/gflownet/utils/molecule/conformer_base.py b/gflownet/utils/molecule/conformer_base.py index 7c9e0c223..7282010e4 100644 --- a/gflownet/utils/molecule/conformer_base.py +++ b/gflownet/utils/molecule/conformer_base.py @@ -1,9 +1,6 @@ import numpy as np - from rdkit import Chem -from rdkit.Chem import AllChem -from rdkit.Chem import rdMolTransforms -from rdkit.Chem import TorsionFingerprints +from rdkit.Chem import AllChem, TorsionFingerprints, rdMolTransforms from rdkit.Geometry.rdGeometry import Point3D from gflownet.utils.molecule import constants @@ -73,7 +70,8 @@ def embed_mol_and_get_conformer(self, mol, extra_opt=False): """Embed RDkit mol with a conformer and return the RDKit conformer object (which is synchronized with the RDKit molecule object) :param mol: rdkit.Chem.rdchem.Mol object defining the molecule - :param extre_opt: bool, if True, an additional optimisation of the conformer will be performed""" + :param extre_opt: bool, if True, an additional optimisation of the conformer will be performed + """ AllChem.EmbedMolecule(mol) if extra_opt: AllChem.MMFFOptimizeMolecule(mol, confId=0, maxIters=1000) @@ -81,7 +79,8 @@ def embed_mol_and_get_conformer(self, mol, extra_opt=False): def set_atom_positions(self, atom_positions): """Set atom positions of the self.rdk_conf to the input atom_positions values - :param atom_positions: 2d numpy array of shape [num atoms, 3] with new atom positions""" + :param atom_positions: 2d numpy array of shape [num atoms, 3] with new atom positions + """ for idx, pos in enumerate(atom_positions): self.rdk_conf.SetAtomPosition(idx, Point3D(*pos)) diff --git a/gflownet/utils/oracle.py b/gflownet/utils/oracle.py index 05da9a440..c4713745a 100644 --- a/gflownet/utils/oracle.py +++ b/gflownet/utils/oracle.py @@ -1,10 +1,10 @@ """import statements""" +import sys + from omegaconf import ListConfig +from potts_utils import load_potts_model, potts_energy from seqfold import dg, fold from utils import * -from potts_utils import load_potts_model -from potts_utils import potts_energy -import sys try: # we don't always install these on every platform from nupack import * @@ -14,8 +14,8 @@ ) pass try: + from bbdob import DeceptiveTrap, FourPeaks, NKLandscape, OneMax, TwoMin, WModel from bbdob.utils import idx2one_hot - from bbdob import OneMax, TwoMin, FourPeaks, DeceptiveTrap, NKLandscape, WModel except: print( "COULD NOT IMPORT BB-DOB ON THIS DEVICE - proceeding, but will crash with BB-DOB oracle selected" @@ -401,7 +401,6 @@ def PottsEnergy(self, queries): return energies def PottsEnergyNew(self, sequences): - # Load the potts model J, h = load_potts_model(435) diff --git a/gflownet/utils/potts_utils.py b/gflownet/utils/potts_utils.py index 50a2aacae..556ad5826 100644 --- a/gflownet/utils/potts_utils.py +++ b/gflownet/utils/potts_utils.py @@ -1,5 +1,5 @@ -import scipy.io as sco import numpy as np +import scipy.io as sco def load_potts_model(num_of_elements): diff --git a/main.py b/main.py index 8f018a732..5c093dd03 100644 --- a/main.py +++ b/main.py @@ -1,13 +1,14 @@ """ Runnable script with hydra capabilities """ -import sys import os import random +import sys + import hydra import pandas as pd import yaml -from omegaconf import OmegaConf, DictConfig +from omegaconf import DictConfig, OmegaConf @hydra.main(config_path="./config", config_name="main", version_base="1.1") @@ -61,8 +62,8 @@ def main(config): def set_seeds(seed): - import torch import numpy as np + import torch torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False diff 
--git a/playground/botorch/mes_exact_deepKernel.py b/playground/botorch/mes_exact_deepKernel.py index e2cb6e7be..b77bb2e89 100644 --- a/playground/botorch/mes_exact_deepKernel.py +++ b/playground/botorch/mes_exact_deepKernel.py @@ -3,18 +3,17 @@ """ import math +import os +import urllib.request +from math import floor -# import tqdm -import torch import gpytorch -from tqdm.notebook import tqdm -import urllib.request -import os -from scipy.io import loadmat -from math import floor +# import tqdm +import torch from botorch.test_functions import Hartmann - +from scipy.io import loadmat +from tqdm.notebook import tqdm """ Initialise the dataset @@ -136,8 +135,8 @@ def train(): """ from botorch.models import SingleTaskGP -from gpytorch.distributions import MultivariateNormal from botorch.posteriors import GPyTorchPosterior +from gpytorch.distributions import MultivariateNormal class myGPModel(SingleTaskGP): diff --git a/playground/botorch/mes_gp.py b/playground/botorch/mes_gp.py index 6307c0e82..b51df0ce6 100644 --- a/playground/botorch/mes_gp.py +++ b/playground/botorch/mes_gp.py @@ -2,21 +2,19 @@ Tutorial: https://botorch.org/tutorials/max_value_entropy """ +from abc import ABC + +import numpy as np import torch # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP -from botorch.test_functions import Hartmann -from botorch.test_functions import Branin +from botorch.test_functions import Branin, Hartmann +from botorch.utils.transforms import normalize, standardize from gpytorch.mlls import ExactMarginalLogLikelihood -from botorch.utils.transforms import standardize, normalize -from torch.optim import Adam -from torch.nn import Linear -from torch.nn import MSELoss -from torch.nn import Sequential, ReLU, Dropout from torch import tensor -import numpy as np -from abc import ABC +from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential +from torch.optim import Adam bounds = torch.tensor(Branin._bounds).T # train_X = bounds[0] + (bounds[1] - bounds[0]) * torch.rand(10, 2) diff --git a/playground/botorch/mes_gp_debug.py b/playground/botorch/mes_gp_debug.py index 4c0ff68fe..06c5a3ed6 100644 --- a/playground/botorch/mes_gp_debug.py +++ b/playground/botorch/mes_gp_debug.py @@ -1,25 +1,24 @@ +from abc import ABC + +import gpytorch +import numpy as np import torch # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann from gpytorch.mlls import ExactMarginalLogLikelihood -from torch.optim import Adam -from torch.nn import Linear -from torch.nn import MSELoss -from torch.nn import Sequential, ReLU, Dropout -from torch import tensor -import numpy as np -from abc import ABC -import gpytorch from gpytorch.priors.torch_priors import GammaPrior - +from torch import tensor +from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential +from torch.optim import Adam neg_hartmann6 = Hartmann(dim=6, negate=True) train_x = torch.rand(10, 6) train_y = neg_hartmann6(train_x).unsqueeze(-1) + # We will use the simplest form of GP model, exact inference class ExactGPModel(gpytorch.models.ExactGP): def __init__(self, train_x, train_y, likelihood): @@ -49,9 +48,9 @@ def forward(self, x): gp.train() likelihood.train() -from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal -from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.models.utils import add_output_dim +from botorch.posteriors.gpytorch import GPyTorchPosterior +from gpytorch.distributions import 
MultitaskMultivariateNormal, MultivariateNormal from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood diff --git a/playground/botorch/mes_nn_bao_fix.py b/playground/botorch/mes_nn_bao_fix.py index c19740903..c4f7de6d0 100644 --- a/playground/botorch/mes_nn_bao_fix.py +++ b/playground/botorch/mes_nn_bao_fix.py @@ -1,17 +1,15 @@ +from abc import ABC + +import numpy as np import torch # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann from gpytorch.mlls import ExactMarginalLogLikelihood -from torch.optim import Adam -from torch.nn import Linear -from torch.nn import MSELoss -from torch.nn import Sequential, ReLU, Dropout from torch import tensor -import numpy as np -from abc import ABC - +from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential +from torch.optim import Adam neg_hartmann6 = Hartmann(dim=6, negate=True) @@ -57,8 +55,8 @@ from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy from botorch.models.model import Model -from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal from botorch.posteriors.gpytorch import GPyTorchPosterior +from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal # from botorch.posteriors. from torch.distributions import Normal diff --git a/playground/botorch/mes_nn_hardcode_gpVal.py b/playground/botorch/mes_nn_hardcode_gpVal.py index c3c2e18af..6320d4f05 100644 --- a/playground/botorch/mes_nn_hardcode_gpVal.py +++ b/playground/botorch/mes_nn_hardcode_gpVal.py @@ -1,17 +1,15 @@ +from abc import ABC + +import numpy as np import torch # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann from gpytorch.mlls import ExactMarginalLogLikelihood -from torch.optim import Adam -from torch.nn import Linear -from torch.nn import MSELoss -from torch.nn import Sequential, ReLU, Dropout from torch import tensor -import numpy as np -from abc import ABC - +from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential +from torch.optim import Adam """ The input to the mvn is mean of dim 1x20 and covar of dim 1 x 20 x 20 @@ -58,8 +56,8 @@ from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy from botorch.models.model import Model -from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal from botorch.posteriors.gpytorch import GPyTorchPosterior +from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal class NN_Model(Model): diff --git a/playground/botorch/mes_nn_like_gp.py b/playground/botorch/mes_nn_like_gp.py index f6478f532..d0664a342 100644 --- a/playground/botorch/mes_nn_like_gp.py +++ b/playground/botorch/mes_nn_like_gp.py @@ -1,24 +1,21 @@ +from abc import ABC + +import numpy as np import torch +from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP -from botorch.test_functions import Hartmann -from gpytorch.mlls import ExactMarginalLogLikelihood -from torch.optim import Adam -from torch.nn import Linear -from torch.nn import MSELoss -from torch.nn import Sequential, ReLU, Dropout -from torch import tensor -import numpy as np -from abc import ABC -from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy from botorch.models.model import Model -from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal from 
botorch.posteriors.gpytorch import GPyTorchPosterior +from botorch.test_functions import Hartmann +from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal +from gpytorch.mlls import ExactMarginalLogLikelihood # from botorch.posteriors. -from torch import distributions - +from torch import distributions, tensor +from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential +from torch.optim import Adam """ Initialise the Dataset diff --git a/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py b/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py index 1b9aae774..2c75fd6a4 100644 --- a/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py +++ b/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py @@ -1,24 +1,21 @@ +from abc import ABC + +import numpy as np import torch +from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP -from botorch.test_functions import Hartmann -from gpytorch.mlls import ExactMarginalLogLikelihood -from torch.optim import Adam -from torch.nn import Linear -from torch.nn import MSELoss -from torch.nn import Sequential, ReLU, Dropout -from torch import tensor -import numpy as np -from abc import ABC - -from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy from botorch.models.model import Model -from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal from botorch.posteriors.gpytorch import GPyTorchPosterior +from botorch.test_functions import Hartmann +from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal +from gpytorch.mlls import ExactMarginalLogLikelihood # from botorch.posteriors. -from torch import distributions +from torch import distributions, tensor +from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential +from torch.optim import Adam """ Create the Dataset diff --git a/playground/botorch/mes_var_deepKernel.py b/playground/botorch/mes_var_deepKernel.py index e2bf8775d..f712eaaf0 100644 --- a/playground/botorch/mes_var_deepKernel.py +++ b/playground/botorch/mes_var_deepKernel.py @@ -5,18 +5,18 @@ Tutorial somewhat inspored by: https://docs.gpytorch.ai/en/stable/examples/04_Variational_and_Approximate_GPs/SVGP_Regression_CUDA.html """ import math +import os +import urllib.request +from math import floor -# import tqdm -import torch import gpytorch -from tqdm.notebook import tqdm -import urllib.request -import os -from scipy.io import loadmat -from math import floor +# import tqdm +import torch from botorch.test_functions import Hartmann +from scipy.io import loadmat from torch.utils.data import DataLoader, Dataset +from tqdm.notebook import tqdm class Data(Dataset): @@ -187,8 +187,8 @@ def train(epoch): from botorch.models import SingleTaskGP -from gpytorch.distributions import MultivariateNormal from botorch.posteriors import GPyTorchPosterior +from gpytorch.distributions import MultivariateNormal class myGPModel(SingleTaskGP): @@ -215,8 +215,8 @@ def posterior( from botorch.acquisition.max_value_entropy_search import ( - qMaxValueEntropy, qLowerBoundMaxValueEntropy, + qMaxValueEntropy, ) proxy = myGPModel(model, train_x, train_y.unsqueeze(-1)) diff --git a/scripts/.old.eval_gflownet.py b/scripts/.old.eval_gflownet.py index a8ad64aa8..4222fead0 100644 --- a/scripts/.old.eval_gflownet.py +++ b/scripts/.old.eval_gflownet.py @@ -1,31 +1,31 @@ """ Computes evaluation metrics from a pre-trained GFlowNet model. 
""" -from argparse import ArgumentParser import copy import gzip import heapq import itertools import os import pickle +import time +from argparse import ArgumentParser from collections import defaultdict from itertools import count, product from pathlib import Path -import yaml -import time import numpy as np import pandas as pd -from scipy.stats import norm -from tqdm import tqdm import torch import torch.nn as nn -from torch.distributions.categorical import Categorical - +import yaml +from aptamers import AptamerSeq from oracle import Oracle +from scipy.stats import norm +from torch.distributions.categorical import Categorical +from tqdm import tqdm from utils import get_config, namespace2dict, numpy2python -from gflownet import GFlowNetAgent, make_mlp, batch2dict -from aptamers import AptamerSeq + +from gflownet import GFlowNetAgent, batch2dict, make_mlp # Float and Long tensors _dev = [torch.device("cpu")] diff --git a/scripts/eval_gflownet.py b/scripts/eval_gflownet.py index 88358dfec..67af72a34 100644 --- a/scripts/eval_gflownet.py +++ b/scripts/eval_gflownet.py @@ -1,14 +1,16 @@ """ Computes evaluation metrics and plots from a pre-trained GFlowNet model. """ +import sys from argparse import ArgumentParser +from pathlib import Path + import hydra +import torch from hydra import compose, initialize, initialize_config_dir from omegaconf import OmegaConf -from pathlib import Path -import torch from torch.distributions.categorical import Categorical -import sys + from gflownet.gflownet import GFlowNetAgent, Policy diff --git a/scripts/get_comet_results.py b/scripts/get_comet_results.py index 9c8327eb7..6dde5d4b7 100644 --- a/scripts/get_comet_results.py +++ b/scripts/get_comet_results.py @@ -1,13 +1,14 @@ """ Script to retrieve results from Comet.ml """ -from comet_ml import Experiment -from comet_ml.api import API -from tqdm import tqdm from argparse import ArgumentParser from pathlib import Path + import numpy as np import pandas as pd +from comet_ml import Experiment +from comet_ml.api import API +from tqdm import tqdm def add_args(parser): diff --git a/scripts/oracle_annotate.py b/scripts/oracle_annotate.py index 1bd6dcf1d..2236eb36f 100644 --- a/scripts/oracle_annotate.py +++ b/scripts/oracle_annotate.py @@ -2,9 +2,8 @@ Annotates a data set with an oracle """ import hydra -from omegaconf import DictConfig, ListConfig, OmegaConf import pandas as pd - +from omegaconf import DictConfig, ListConfig, OmegaConf from oracle import Oracle diff --git a/scripts/oracle_sampler.py b/scripts/oracle_sampler.py index 571dd3fd9..34eac2770 100644 --- a/scripts/oracle_sampler.py +++ b/scripts/oracle_sampler.py @@ -1,18 +1,17 @@ """ Script to create data set of with nupack labels. 
""" -from argparse import ArgumentParser import os import pickle -from pathlib import Path -import yaml import time +from argparse import ArgumentParser +from pathlib import Path import numpy as np import pandas as pd -from tqdm import tqdm - +import yaml from oracle import Oracle +from tqdm import tqdm from utils import get_config, namespace2dict, numpy2python diff --git a/setup_conformer.sh b/setup_conformer.sh index c9452ecfb..0ef7c720b 100644 --- a/setup_conformer.sh +++ b/setup_conformer.sh @@ -21,4 +21,4 @@ python -m pip install numpy pandas hydra-core tqdm torchtyping six xtb scikit-le # Conditional requirements python -m pip install wandb matplotlib plotly gdown # Dev packages -# python -m pip install black flake8 isort pylint ipdb jupyter pytest +# python -m pip install black flake8 isort pylint ipdb jupyter pytest pytest-repeat diff --git a/setup_gflownet.sh b/setup_gflownet.sh index 637cfc208..d7ce20b16 100644 --- a/setup_gflownet.sh +++ b/setup_gflownet.sh @@ -18,4 +18,4 @@ python -m pip install numpy pandas hydra-core tqdm torchtyping scikit-learn # Conditional requirements to run python -m pip install wandb matplotlib plotly pymatgen # Dev packages -# python -m pip install black flake8 isort pylint ipdb jupyter pytest +# python -m pip install black flake8 isort pylint ipdb jupyter pytest pytest-repeat diff --git a/tests/gflownet/envs/common.py b/tests/gflownet/envs/common.py new file mode 100644 index 000000000..50df7c127 --- /dev/null +++ b/tests/gflownet/envs/common.py @@ -0,0 +1,222 @@ +import hydra +import numpy as np +import pytest +import torch +import yaml +from hydra import compose, initialize + + +def test__all_env_common(env): + test__get_parents_step_get_mask__are_compatible(env) + test__sample_backwards_reaches_source(env) + test__state_conversions_are_reversible(env) + test__get_parents__returns_no_parents_in_initial_state(env) + test__gflownet_minimal_runs(env) + test__sample_actions__get_logprobs__return_valid_actions_and_logprobs(env) + test__get_parents__returns_same_state_and_eos_if_done(env) + test__step__returns_same_state_action_and_invalid_if_done(env) + test__actions2indices__returns_expected_tensor(env) + + +def test__continuous_env_common(env): + # test__state_conversions_are_reversible(env) + test__get_parents__returns_no_parents_in_initial_state(env) + # test__gflownet_minimal_runs(env) + # test__sample_actions__get_logprobs__return_valid_actions_and_logprobs(env) + test__get_parents__returns_same_state_and_eos_if_done(env) + test__step__returns_same_state_action_and_invalid_if_done(env) + test__actions2indices__returns_expected_tensor(env) + + +@pytest.mark.repeat(100) +def test__get_parents_step_get_mask__are_compatible(env): + env = env.reset() + n_actions = 0 + while not env.done: + state = env.state + # Sample random action + mask_invalid = torch.unsqueeze( + torch.BoolTensor(env.get_mask_invalid_actions_forward()), 0 + ) + random_policy = torch.unsqueeze( + torch.tensor(env.random_policy_output, dtype=env.float), 0 + ) + actions, _ = env.sample_actions( + policy_outputs=random_policy, mask_invalid_actions=mask_invalid + ) + next_state, action, valid = env.step(actions[0]) + if valid is False: + continue + n_actions += 1 + assert n_actions <= env.max_traj_length + assert env.n_actions == n_actions + parents, parents_a = env.get_parents() + assert state in parents + assert len(parents) == len(parents_a) + for p, p_a in zip(parents, parents_a): + mask = env.get_mask_invalid_actions_forward(p, False) + assert p_a in env.action_space + assert 
mask[env.action_space.index(p_a)] is False + + +@pytest.mark.repeat(100) +def test__sample_backwards_reaches_source(env, n=100): + if hasattr(env, "get_all_terminating_states"): + x = env.get_all_terminating_states() + elif hasattr(env, "get_uniform_terminating_states"): + x = env.get_uniform_terminating_states(n) + else: + print( + """ + Environment does not have neither get_all_terminating_states() nor + get_uniform_terminating_states(). Backward sampling will not be tested. + """ + ) + return + for state in x: + env.set_state(state, done=True) + n_actions = 0 + while True: + if env.state == env.source: + assert True + break + parents, parents_a = env.get_parents() + assert len(parents) > 0 + # Sample random parent + parent = parents[np.random.permutation(len(parents))[0]] + env.set_state(parent) + n_actions += 1 + assert n_actions <= env.max_traj_length + + +@pytest.mark.repeat(100) +def test__state_conversions_are_reversible(env): + env = env.reset() + while not env.done: + state = env.state + assert state == env.policy2state(env.state2policy(state)) + for el1, el2 in zip(state, env.readable2state(env.state2readable(state))): + assert np.isclose(el1, el2) + # Sample random action + mask_invalid = torch.unsqueeze( + torch.BoolTensor(env.get_mask_invalid_actions_forward()), 0 + ) + random_policy = torch.unsqueeze( + torch.tensor(env.random_policy_output, dtype=env.float), 0 + ) + actions, _ = env.sample_actions( + policy_outputs=random_policy, mask_invalid_actions=mask_invalid + ) + env.step(actions[0]) + + +def test__get_parents__returns_no_parents_in_initial_state(env): + parents, actions = env.get_parents() + assert len(parents) == 0 + assert len(actions) == 0 + + +def test__default_config_equals_default_args(env, env_config_path): + with open(env_config_path, "r") as f: + config_env = yaml.safe_load(f) + env_config = hydra.utils.instantiate(config) + assert True + + +def test__gflownet_minimal_runs(env): + # Load config + with initialize(version_base="1.1", config_path="../../../config", job_name="xxx"): + config = compose(config_name="tests") + # Logger + logger = hydra.utils.instantiate(config.logger, config, _recursive_=False) + # Proxy + proxy = hydra.utils.instantiate( + config.proxy, device=config.device, float_precision=config.float_precision + ) + # Set proxy in env + env.proxy = proxy + # No buffer + config.env.buffer.train = None + config.env.buffer.test = None + # Set 1 training step + config.gflownet.optimizer.n_train_steps = 1 + # GFlowNet agent + gflownet = hydra.utils.instantiate( + config.gflownet, + device=config.device, + float_precision=config.float_precision, + env=env, + buffer=config.env.buffer, + logger=logger, + ) + gflownet.train() + assert True + + +@pytest.mark.repeat(100) +def test__sample_actions__get_logprobs__return_valid_actions_and_logprobs(env): + env = env.reset() + while not env.done: + policy_outputs = torch.unsqueeze(torch.tensor(env.random_policy_output), 0) + mask_invalid = env.get_mask_invalid_actions_forward() + valid_actions = [a for a, m in zip(env.action_space, mask_invalid) if not m] + masks_invalid_torch = torch.unsqueeze(torch.BoolTensor(mask_invalid), 0) + actions, logprobs_sa = env.sample_actions( + policy_outputs=policy_outputs, mask_invalid_actions=masks_invalid_torch + ) + actions_torch = torch.tensor(actions) + logprobs_glp = env.get_logprobs( + policy_outputs=policy_outputs, + is_forward=True, + actions=actions_torch, + states_target=None, + mask_invalid_actions=masks_invalid_torch, + ) + action = actions[0] + assert action in 
valid_actions + assert torch.equal(logprobs_sa, logprobs_glp) + env.step(action) + + +def test__get_parents__returns_no_parents_in_initial_state(env): + env.reset() + parents, actions = env.get_parents() + assert len(parents) == 0 + assert len(actions) == 0 + + +def test__get_parents__returns_same_state_and_eos_if_done(env): + env.set_state(env.state, done=True) + parents, actions = env.get_parents() + assert parents == [env.state] + assert actions == [env.action_space[-1]] + + +@pytest.mark.repeat(10) +def test__step__returns_same_state_action_and_invalid_if_done(env): + # Sample random action + mask_invalid = torch.unsqueeze( + torch.BoolTensor(env.get_mask_invalid_actions_forward()), 0 + ) + random_policy = torch.unsqueeze( + torch.tensor(env.random_policy_output, dtype=env.float), 0 + ) + actions, _ = env.sample_actions( + policy_outputs=random_policy, mask_invalid_actions=mask_invalid + ) + action = actions[0] + env.set_state(env.state, done=True) + next_state, action_step, valid = env.step(action) + assert next_state == env.state + assert action_step == action + assert valid is False + + +@pytest.mark.repeat(10) +def test__actions2indices__returns_expected_tensor(env, batch_size=100): + action_space = env.action_space_torch + indices_rand = torch.randint(low=0, high=action_space.shape[0], size=(batch_size,)) + actions = action_space[indices_rand, :] + action_indices = env.actions2indices(actions) + assert torch.equal(action_indices, indices_rand) + assert torch.equal(action_space[action_indices], actions) diff --git a/tests/gflownet/envs/test_crystals.py b/tests/gflownet/envs/test_crystals.py index 274353857..a9d05e1a4 100644 --- a/tests/gflownet/envs/test_crystals.py +++ b/tests/gflownet/envs/test_crystals.py @@ -1,3 +1,4 @@ +import common import pytest import torch @@ -82,7 +83,7 @@ def test__reset(env): (84, 3, 8), ], ) -def test__get_actions_space__returns_correct_number_of_actions( +def test__get_action_space__returns_correct_number_of_actions( elements, min_atom_i, max_atom_i ): environment = Crystal( @@ -90,17 +91,17 @@ def test__get_actions_space__returns_correct_number_of_actions( ) exp_n_actions = elements * (max_atom_i - min_atom_i + 1) + 1 - assert len(environment.get_actions_space()) == exp_n_actions + assert len(environment.get_action_space()) == exp_n_actions @pytest.mark.parametrize( "elements", [[1, 2, 3, 4], [1, 12, 84], [42]], ) -def test__get_actions_space__returns_actions_for_each_element(elements): +def test__get_action_space__returns_actions_for_each_element(elements): environment = Crystal(elements=elements) - elements_in_action_space = set(e for e, n in environment.get_actions_space()) + elements_in_action_space = set(e for e, n in environment.get_action_space()) exp_elements_with_eos = set(elements + [-1]) assert elements_in_action_space == exp_elements_with_eos @@ -117,7 +118,7 @@ def test__get_actions_space__returns_actions_for_each_element(elements): (84, 3, 8), ], ) -def test__get_actions_space__returns_actions_for_each_step_size( +def test__get_action_space__returns_actions_for_each_step_size( elements, min_atom_i, max_atom_i ): environment = Crystal( @@ -125,7 +126,7 @@ def test__get_actions_space__returns_actions_for_each_step_size( ) step_sizes_in_action_space = set( - n for e, n in environment.get_actions_space()[:-1] + n for e, n in environment.get_action_space()[:-1] ) # skip eos exp_step_sizes = set(range(min_atom_i, max_atom_i + 1)) @@ -154,10 +155,7 @@ def test__get_mask_invalid_actions__already_set_elements_are_masked(env, state): def 
test__get_parents__returns_no_parents_in_initial_state(env): - parents, actions = env.get_parents() - - assert len(parents) == 0 - assert len(actions) == 0 + return common.test__get_parents__returns_no_parents_in_initial_state(env) def test__get_parents__returns_parents_after_step(env): diff --git a/tests/gflownet/envs/test_ctorus.py b/tests/gflownet/envs/test_ctorus.py new file mode 100644 index 000000000..7c469430a --- /dev/null +++ b/tests/gflownet/envs/test_ctorus.py @@ -0,0 +1,28 @@ +import common +import numpy as np +import pytest +import torch + +from gflownet.envs.ctorus import ContinuousTorus + + +@pytest.fixture +def env(): + return ContinuousTorus(n_dim=2, length_traj=3) + + +@pytest.mark.parametrize( + "action_space", + [ + [ + (0.0, 0.0), + (np.inf, np.inf), + ], + ], +) +def test__get_action_space__returns_expected(env, action_space): + assert set(action_space) == set(env.action_space) + + +def test__continuous_env_common(env): + return common.test__continuous_env_common(env) diff --git a/tests/gflownet/envs/test_grid.py b/tests/gflownet/envs/test_grid.py new file mode 100644 index 000000000..1cd2f9e44 --- /dev/null +++ b/tests/gflownet/envs/test_grid.py @@ -0,0 +1,102 @@ +import common +import pytest +import torch + +from gflownet.envs.grid import Grid + + +@pytest.fixture +def env(): + return Grid(n_dim=3, length=5, cell_min=-1.0, cell_max=1.0) + + +@pytest.fixture +def env_extended_action_space_2d(): + return Grid( + n_dim=2, + length=5, + max_increment=2, + max_dim_per_action=-1, + cell_min=-1.0, + cell_max=1.0, + ) + + +@pytest.fixture +def env_extended_action_space_3d(): + return Grid( + n_dim=3, + length=5, + max_increment=2, + max_dim_per_action=3, + cell_min=-1.0, + cell_max=1.0, + ) + + +@pytest.fixture +def env_default(): + return Grid() + + +@pytest.fixture +def config_path(): + return "../../../config/env/grid.yaml" + + +@pytest.mark.parametrize( + "state, state2oracle", + [ + ( + [0, 0, 0], + [-1.0, -1.0, -1.0], + ), + ( + [4, 4, 4], + [1.0, 1.0, 1.0], + ), + ( + [1, 2, 3], + [-0.5, 0.0, 0.5], + ), + ( + [4, 0, 1], + [1.0, -1.0, -0.5], + ), + ], +) +def test__state2oracle__returns_expected(env, state, state2oracle): + assert state2oracle == env.state2oracle(state) + + +@pytest.mark.parametrize( + "states, statebatch2oracle", + [ + ( + [[0, 0, 0], [4, 4, 4], [1, 2, 3], [4, 0, 1]], + [[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0], [-0.5, 0.0, 0.5], [1.0, -1.0, -0.5]], + ), + ], +) +def test__statebatch2oracle__returns_expected(env, states, statebatch2oracle): + assert torch.equal(torch.Tensor(statebatch2oracle), env.statebatch2oracle(states)) + + +@pytest.mark.parametrize( + "action_space", + [ + [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), (0, 2), (1, 2), (2, 2)], + ], +) +def test__get_action_space__returns_expected( + env_extended_action_space_2d, action_space +): + assert set(action_space) == set(env_extended_action_space_2d.action_space) + + +def test__all_env_common(env): + return common.test__all_env_common(env) + + +def test__all_env_common(env_extended_action_space_3d): + return common.test__all_env_common(env_extended_action_space_3d) diff --git a/tests/gflownet/envs/test_htorus.py b/tests/gflownet/envs/test_htorus.py new file mode 100644 index 000000000..1b3bd7cfd --- /dev/null +++ b/tests/gflownet/envs/test_htorus.py @@ -0,0 +1,29 @@ +import common +import numpy as np +import pytest +import torch + +from gflownet.envs.htorus import HybridTorus + + +@pytest.fixture +def env(): + return HybridTorus(n_dim=2, length_traj=3) + + +@pytest.mark.parametrize( + 
"action_space", + [ + [ + (0, 0), + (1, 0), + (2, 0), + ], + ], +) +def test__get_action_space__returns_expected(env, action_space): + assert set(action_space) == set(env.action_space) + + +def test__continuous_env_common(env): + return common.test__continuous_env_common(env) diff --git a/tests/gflownet/envs/test_spacegroup.py b/tests/gflownet/envs/test_spacegroup.py index b72855d62..105cc9c31 100644 --- a/tests/gflownet/envs/test_spacegroup.py +++ b/tests/gflownet/envs/test_spacegroup.py @@ -1,8 +1,9 @@ +import common +import numpy as np +import pymatgen.symmetry.groups as pmgg import pytest import torch -import numpy as np -import pymatgen.symmetry.groups as pmgg from gflownet.envs.spacegroup import SpaceGroup @@ -19,7 +20,7 @@ def test__environment__initializes_properly(): def test__environment__action_space_has_eos(): env = SpaceGroup() - assert (env.eos, 0) in env.action_space + assert env.eos in env.action_space @pytest.mark.parametrize( @@ -72,7 +73,9 @@ def test__environment__action_space_has_eos(): ), ], ) -def test__get_mask_invalid_actions_forward__masks_expected_action(env, state, action, expected): +def test__get_mask_invalid_actions_forward__masks_expected_action( + env, state, action, expected +): assert action in env.action_space mask = env.get_mask_invalid_actions_forward(state, False) assert mask[env.action_space.index(action)] == expected @@ -100,23 +103,7 @@ def test__state2readable2state(env, state): ) -# TODO: make common to all environments -def test__get_parents_step_get_mask__are_compatible(env, n=100): - for traj in range(n): - env = env.reset() - while not env.done: - mask_invalid = env.get_mask_invalid_actions_forward() - valid_actions = [a for a, m in zip(env.action_space, mask_invalid) if not m] - action = tuple(np.random.permutation(valid_actions)[0]) - env.step(action) - parents, parents_a = env.get_parents() - assert len(parents) == len(parents_a) - for p, p_a in zip(parents, parents_a): - mask = env.get_mask_invalid_actions_forward(p, False) - assert p_a in env.action_space - assert mask[env.action_space.index(p_a)] == False - - +@pytest.mark.skip(reason="Takes considerable time") def test__states_are_compatible_with_pymatgen(env): for idx in range(env.n_space_groups): env = env.reset() @@ -130,8 +117,5 @@ def test__states_are_compatible_with_pymatgen(env): assert sg.point_group in point_groups -# TODO: make common to all environments -def test__get_parents__returns_no_parents_in_initial_state(env): - parents, actions = env.get_parents() - assert len(parents) == 0 - assert len(actions) == 0 +def test__all_env_common(env): + return common.test__all_env_common(env) diff --git a/tests/gflownet/envs/test_torus.py b/tests/gflownet/envs/test_torus.py new file mode 100644 index 000000000..753f21e91 --- /dev/null +++ b/tests/gflownet/envs/test_torus.py @@ -0,0 +1,78 @@ +import common +import numpy as np +import pytest +import torch + +from gflownet.envs.torus import Torus + + +@pytest.fixture +def env(): + return Torus(n_dim=3, n_angles=5) + + +@pytest.fixture +def env_extended_action_space_2d(): + return Torus( + n_dim=2, + n_angles=5, + max_increment=2, + max_dim_per_action=-1, + ) + + +@pytest.fixture +def env_extended_action_space_3d(): + return Torus( + n_dim=3, + n_angles=5, + max_increment=2, + max_dim_per_action=2, + ) + + +@pytest.mark.parametrize( + "action_space", + [ + [ + (-2, -2), + (-2, -1), + (-2, 0), + (-2, 1), + (-2, 2), + (-1, -2), + (-1, -1), + (-1, 0), + (-1, 1), + (-1, 2), + (0, -2), + (0, -1), + (0, 0), + (0, 1), + (0, 2), + (1, -2), + (1, 
-1), + (1, 0), + (1, 1), + (1, 2), + (2, -2), + (2, -1), + (2, 0), + (2, 1), + (2, 2), + (3, 3), + ], + ], +) +def test__get_action_space__returns_expected( + env_extended_action_space_2d, action_space +): + assert set(action_space) == set(env_extended_action_space_2d.action_space) + + +def test__all_env_common(env): + return common.test__all_env_common(env) + + +def test__all_env_common(env_extended_action_space_3d): + return common.test__all_env_common(env_extended_action_space_3d) diff --git a/tests/gflownet/proxy/test_molecule.py b/tests/gflownet/proxy/test_molecule.py index f844dda60..02c27e76f 100644 --- a/tests/gflownet/proxy/test_molecule.py +++ b/tests/gflownet/proxy/test_molecule.py @@ -1,10 +1,9 @@ -import pytest - import numpy as np +import pytest import torch -from gflownet.utils.molecule.conformer_base import get_dummy_ad_conf_base from gflownet.proxy.molecule import TorchANIMoleculeEnergy +from gflownet.utils.molecule.conformer_base import get_dummy_ad_conf_base @pytest.fixture()