[WIP] Make adding new Policy Models flexible #327

Draft · wants to merge 6 commits into main
1 change: 1 addition & 0 deletions .gitignore
@@ -15,3 +15,4 @@ playground/
!docs/requirements-docs.txt
.DS_Store
docs/_build/
logs
3 changes: 1 addition & 2 deletions config/policy/mlp.yaml
@@ -1,9 +1,8 @@
_target_: gflownet.policy.base.Policy
_target_: gflownet.policy.mlp.MLPPolicy

shared: null

forward:
type: mlp
n_hid: 128
n_layers: 2
checkpoint: null
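
The policy config now points _target_ at the concrete gflownet.policy.mlp.MLPPolicy class, so the separate "type: mlp" key is no longer needed. Below is a minimal sketch of how such a _target_ config can be resolved; it assumes Hydra >= 1.2 (for _partial_) and is illustrative rather than the exact call the GFlowNet agent makes.

from hydra.utils import instantiate

# Resolve the class named by `_target_` into a partial constructor; the
# runtime-only arguments (environment, device, float precision) are bound later.
policy_partial = instantiate(
    {"_target_": "gflownet.policy.mlp.MLPPolicy", "_partial_": True}
)
# `policy_config` and `env` below are placeholders for the parsed policy config
# and the environment instance available at training time:
# policy = policy_partial(
#     config=policy_config, env=env, device="cpu", float_precision=32
# )
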
77 changes: 6 additions & 71 deletions gflownet/policy/base.py
@@ -2,12 +2,11 @@

import torch
from omegaconf import OmegaConf
from torch import nn

from gflownet.utils.common import set_device, set_float_precision


class ModelBase(ABC):
class Policy(ABC):
def __init__(self, config, env, device, float_precision, base=None):
# Device and float precision
self.device = set_device(device)
@@ -21,82 +20,18 @@ def __init__(self, config, env, device, float_precision, base=None):
self.base = base

self.parse_config(config)
self.instantiate()

def parse_config(self, config):
# If config is null, default to uniform
if config is None:
config = OmegaConf.create()
config.type = "uniform"
self.checkpoint = config.get("checkpoint", None)
self.shared_weights = config.get("shared_weights", False)
self.n_hid = config.get("n_hid", None)
self.n_layers = config.get("n_layers", None)
self.tail = config.get("tail", [])
if "type" in config:
self.type = config.type
elif self.shared_weights:
self.type = self.base.type
else:
raise "Policy type must be defined if shared_weights is False"

@abstractmethod
def instantiate(self):
pass

def __call__(self, states):
return self.model(states)

def make_mlp(self, activation):
"""
Defines an MLP with no top layer activation
If share_weight == True,
baseModel (the model with which weights are to be shared) must be provided
Args
----
layers_dim : list
Dimensionality of each layer
activation : Activation
Activation function
"""
if self.shared_weights == True and self.base is not None:
mlp = nn.Sequential(
self.base.model[:-1],
nn.Linear(
self.base.model[-1].in_features, self.base.model[-1].out_features
),
)
return mlp
elif self.shared_weights == False:
layers_dim = (
[self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)]
)
mlp = nn.Sequential(
*(
sum(
[
[nn.Linear(idim, odim)]
+ ([activation] if n < len(layers_dim) - 2 else [])
for n, (idim, odim) in enumerate(
zip(layers_dim, layers_dim[1:])
)
],
[],
)
+ self.tail
)
)
return mlp
else:
raise ValueError(
"Base Model must be provided when shared_weights is set to True"
)


class Policy(ModelBase):
def __init__(self, config, env, device, float_precision, base=None):
super().__init__(config, env, device, float_precision, base)

self.instantiate()
self.type = "uniform"

def instantiate(self):
if self.type == "fixed":
@@ -105,12 +40,12 @@ def instantiate(self):
elif self.type == "uniform":
self.model = self.uniform_distribution
self.is_model = False
elif self.type == "mlp":
self.model = self.make_mlp(nn.LeakyReLU()).to(self.device)
self.is_model = True
else:
raise "Policy model type not defined"

def __call__(self, states):
return self.model(states)

def fixed_distribution(self, states):
"""
Returns the fixed distribution specified by the environment.
Empty file added gflownet/policy/cnn.py
Empty file added gflownet/policy/gnn.py
78 changes: 78 additions & 0 deletions gflownet/policy/mlp.py
@@ -0,0 +1,78 @@
from omegaconf import OmegaConf
from torch import nn

from gflownet.policy.base import Policy


class MLPPolicy(Policy):
def __init__(self, config, env, device, float_precision, base=None):
super().__init__(
config=config,
env=env,
device=device,
float_precision=float_precision,
base=base,
)
self.is_model = True

def make_mlp(self, activation):
"""
Defines an MLP with no top layer activation
If share_weight == True,
baseModel (the model with which weights are to be shared) must be provided
Args
----
layers_dim : list
Dimensionality of each layer
activation : Activation
Activation function
"""
if self.shared_weights == True and self.base is not None:
mlp = nn.Sequential(
self.base.model[:-1],
nn.Linear(
self.base.model[-1].in_features, self.base.model[-1].out_features
),
)
return mlp
elif self.shared_weights == False:
layers_dim = (
[self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)]
)
mlp = nn.Sequential(
*(
sum(
[
[nn.Linear(idim, odim)]
+ ([activation] if n < len(layers_dim) - 2 else [])
for n, (idim, odim) in enumerate(
zip(layers_dim, layers_dim[1:])
)
],
[],
)
+ self.tail
)
)
return mlp
else:
raise ValueError(
"Base Model must be provided when shared_weights is set to True"
)

def parse_config(self, config):
if config is None:
config = OmegaConf.create()
config.type = "mlp"
self.checkpoint = config.get("checkpoint", None)
self.shared_weights = config.get("shared_weights", False)
self.n_hid = config.get("n_hid", 128)
self.n_layers = config.get("n_layers", 2)
self.tail = config.get("tail", [])
self.reload_ckpt = config.get("reload_ckpt", False)

def instantiate(self):
self.model = self.make_mlp(nn.LeakyReLU()).to(self.device)

def __call__(self, states):
return self.model(states)
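
With the MLP-specific logic moved into MLPPolicy, adding a new policy model reduces to subclassing Policy and overriding parse_config and instantiate. The sketch below shows what the currently empty gflownet/policy/cnn.py could eventually look like; the architecture, config keys, and defaults are placeholder assumptions, not part of this PR.

from omegaconf import OmegaConf
from torch import nn

from gflownet.policy.base import Policy


class CNNPolicy(Policy):
    def __init__(self, config, env, device, float_precision, base=None):
        super().__init__(
            config=config,
            env=env,
            device=device,
            float_precision=float_precision,
            base=base,
        )
        self.is_model = True

    def parse_config(self, config):
        # Read model hyperparameters with defaults, mirroring MLPPolicy.
        if config is None:
            config = OmegaConf.create()
        self.checkpoint = config.get("checkpoint", None)
        self.n_channels = config.get("n_channels", 16)  # placeholder default
        self.kernel_size = config.get("kernel_size", 3)  # placeholder default

    def instantiate(self):
        # Build the torch module and move it to the policy device; output_dim
        # is set by the base class from the environment, as used in
        # MLPPolicy.make_mlp.
        self.model = nn.Sequential(
            nn.Conv1d(1, self.n_channels, self.kernel_size, padding="same"),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.LazyLinear(self.output_dim),
        ).to(self.device)

    def __call__(self, states):
        # Placeholder: treat the flat policy input as a 1-channel sequence.
        return self.model(states.unsqueeze(1))

A matching config file (for example a hypothetical config/policy/cnn.yaml with _target_: gflownet.policy.cnn.CNNPolicy) would then be enough to select the new model.
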
2 changes: 2 additions & 0 deletions scripts/crystal/eval_crystalgflownet.py
@@ -1,6 +1,7 @@
"""
Computes evaluation metrics and plots from a pre-trained GFlowNet model.
"""

import pickle
import shutil
import sys
@@ -14,6 +15,7 @@
sys.path.append(str(Path(__file__).resolve().parent.parent))

from crystalrandom import generate_random_crystals

from gflownet.gflownet import GFlowNetAgent
from gflownet.utils.common import load_gflow_net_from_run_path
from gflownet.utils.policy import parse_policy_config
50 changes: 25 additions & 25 deletions scripts/crystal/eval_gflownet.py
@@ -8,8 +8,8 @@
from argparse import ArgumentParser
from pathlib import Path

import pandas as pd
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

@@ -229,30 +229,30 @@ def main(args):
env.proxy.is_bandgap = False

# Test
# samples = [env.readable2state(readable) for readable in gflownet.buffer.test["samples"]]
# energies = env.proxy(env.states2proxy(samples))
# df = pd.DataFrame(
# {
# "readable": gflownet.buffer.test["samples"],
# "energies": energies.tolist(),
# }
# )
# df.to_csv(output_dir / f"val.csv")
# dct = {"x": samples, "energy": energies.tolist()}
# pickle.dump(dct, open(output_dir / f"val.pkl", "wb"))
#
# # Train
# samples = [env.readable2state(readable) for readable in gflownet.buffer.train["samples"]]
# energies = env.proxy(env.states2proxy(samples))
# df = pd.DataFrame(
# {
# "readable": gflownet.buffer.train["samples"],
# "energies": energies.tolist(),
# }
# )
# df.to_csv(output_dir / f"train.csv")
# dct = {"x": samples, "energy": energies.tolist()}
# pickle.dump(dct, open(output_dir / f"train.pkl", "wb"))
# samples = [env.readable2state(readable) for readable in gflownet.buffer.test["samples"]]
# energies = env.proxy(env.states2proxy(samples))
# df = pd.DataFrame(
# {
# "readable": gflownet.buffer.test["samples"],
# "energies": energies.tolist(),
# }
# )
# df.to_csv(output_dir / f"val.csv")
# dct = {"x": samples, "energy": energies.tolist()}
# pickle.dump(dct, open(output_dir / f"val.pkl", "wb"))
#
# # Train
# samples = [env.readable2state(readable) for readable in gflownet.buffer.train["samples"]]
# energies = env.proxy(env.states2proxy(samples))
# df = pd.DataFrame(
# {
# "readable": gflownet.buffer.train["samples"],
# "energies": energies.tolist(),
# }
# )
# df.to_csv(output_dir / f"train.csv")
# dct = {"x": samples, "energy": energies.tolist()}
# pickle.dump(dct, open(output_dir / f"train.pkl", "wb"))

if args.n_samples > 0 and args.n_samples <= 1e5 and not args.random_only:
print(
5 changes: 3 additions & 2 deletions scripts/crystal/sample_uniform_with_rewards.py
@@ -3,16 +3,17 @@
should be run with the same config as main.py, e.g.
python sample_uniform_with_rewards.py +experiments=crystals/albatross_sg_first logger.do.online=False user=sasha
"""

import pickle
import sys

import hydra
import pandas as pd
from crystalrandom import generate_random_crystals_uniform

from gflownet.utils.common import chdir_random_subdir
from gflownet.utils.policy import parse_policy_config

from crystalrandom import generate_random_crystals_uniform


@hydra.main(config_path="../../config", config_name="main", version_base="1.1")
def main(config):
1 change: 1 addition & 0 deletions scripts/pyxtal/compatibility_sg_n_atoms.py
@@ -3,6 +3,7 @@
combinations spanned by the N_SYMMETRY_GROUPS, N_SPECIES and MAX_N_ATOMS. The results
are printed to stdout.
"""

import itertools
import time

1 change: 1 addition & 0 deletions scripts/pyxtal/get_n_compatible_for_sg.py
@@ -3,6 +3,7 @@
spanned by the --max_n_atoms and --max_n_species. The results are written to a file in
--output_dir.
"""

import itertools
import time
from argparse import ArgumentParser
1 change: 1 addition & 0 deletions scripts/pyxtal/pyxtal_vs_pymatgen.py
@@ -2,6 +2,7 @@
A simple script to determine which space group symbols are different in pyxtal and
pymatgen.
"""

from argparse import ArgumentParser

from pymatgen.symmetry.groups import (