Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Make adding new Policy Models flexible #327

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ playground/
!docs/requirements-docs.txt
.DS_Store
docs/_build/
logs
2 changes: 2 additions & 0 deletions config/env/tetris.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ height: 20
pieces: ["I", "J", "L", "O", "S", "T", "Z"]
# Allowed roations
rotations: [0, 90, 180, 270]
# Don't flatten if using CNN
flatten: True
Comment on lines +14 to +15
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we move the flattening from the environment to the policy, then we don't need this, and we can simply remove it.

# Other config
allow_redundant_rotations: False
allow_eos_before_full: False
Expand Down
16 changes: 16 additions & 0 deletions config/policy/cnn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
_target_: gflownet.policy.cnn.CNNPolicy

shared: null

forward:
n_layers: 2
channels: [16, 32]
kernel_sizes: [[3, 3], [2, 2]] # Each tuple represents (height, width)
strides: [[1, 1], [1, 1]] # Each tuple represents (vertical_stride, horizontal_stride)
checkpoint: null
reload_ckpt: False

backward:
shared_weights: True
checkpoint: null
reload_ckpt: False
3 changes: 1 addition & 2 deletions config/policy/mlp.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
_target_: gflownet.policy.base.Policy
_target_: gflownet.policy.mlp.MLPPolicy

shared: null

forward:
type: mlp
n_hid: 128
n_layers: 2
checkpoint: null
Expand Down
8 changes: 6 additions & 2 deletions gflownet/envs/tetris.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
height: int = 20,
pieces: List = ["I", "J", "L", "O", "S", "T", "Z"],
rotations: List = [0, 90, 180, 270],
flatten: bool = True,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we move the flattening from the environment to the policy, then we don't need this.

allow_redundant_rotations: bool = False,
allow_eos_before_full: bool = False,
**kwargs,
Expand All @@ -87,6 +88,7 @@ def __init__(
self.height = height
self.pieces = pieces
self.rotations = rotations
self.flatten = flatten
self.allow_redundant_rotations = allow_redundant_rotations
self.allow_eos_before_full = allow_eos_before_full
self.max_pieces_per_type = 100
Expand Down Expand Up @@ -307,7 +309,9 @@ def states2policy(
A tensor containing all the states in the batch.
"""
states = tint(states, device=self.device, int_type=self.int)
return self.states2proxy(states).flatten(start_dim=1).to(self.float)
if self.flatten:
return self.states2proxy(states).flatten(start_dim=1).to(self.float)
return self.states2proxy(states).to(self.float)
Comment on lines +312 to +314
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alexhernandezgarcia This is a temporary solution to make the CNN policy work on Tetris env. But normally the flattening should happen inside the model but not in the environment (see my other comments)

if you are okay with that, then I can update.


def state2readable(self, state: Optional[TensorType["height", "width"]] = None):
"""
Expand Down Expand Up @@ -581,7 +585,7 @@ def _plot_board(board, ax: Axes, cellsize: int = 20, linewidth: int = 2):
linewidth : int
The width of the separation between cells, in pixels.
"""
board = board.clone().numpy()
board = board.clone().cpu().numpy()
height = board.shape[0] * cellsize
width = board.shape[1] * cellsize
board_img = 128 * np.ones(
Expand Down
2 changes: 1 addition & 1 deletion gflownet/evaluator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def compute_log_prob_metrics(self, x_tt, metrics=None):

if "corr_prob_traj_rewards" in metrics:
lp_metrics["corr_prob_traj_rewards"] = np.corrcoef(
np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt
np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt.cpu().numpy()
)[0, 1]

if "var_logrewards_logp" in metrics:
Expand Down
81 changes: 6 additions & 75 deletions gflownet/policy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

import torch
from omegaconf import OmegaConf
from torch import nn

from gflownet.utils.common import set_device, set_float_precision


class ModelBase(ABC):
class Policy:
def __init__(self, config, env, device, float_precision, base=None):
# Device and float precision
self.device = set_device(device)
Expand All @@ -21,82 +20,14 @@ def __init__(self, config, env, device, float_precision, base=None):
self.base = base

self.parse_config(config)
self.instantiate()

def parse_config(self, config):
engmubarak48 marked this conversation as resolved.
Show resolved Hide resolved
# If config is null, default to uniform
if config is None:
config = OmegaConf.create()
config.type = "uniform"
self.type = config.get("type", "uniform")
self.checkpoint = config.get("checkpoint", None)
self.shared_weights = config.get("shared_weights", False)
self.n_hid = config.get("n_hid", None)
self.n_layers = config.get("n_layers", None)
self.tail = config.get("tail", [])
if "type" in config:
self.type = config.type
elif self.shared_weights:
self.type = self.base.type
else:
raise "Policy type must be defined if shared_weights is False"

@abstractmethod
def instantiate(self):
pass

def __call__(self, states):
return self.model(states)

def make_mlp(self, activation):
"""
Defines an MLP with no top layer activation
If share_weight == True,
baseModel (the model with which weights are to be shared) must be provided
Args
----
layers_dim : list
Dimensionality of each layer
activation : Activation
Activation function
"""
if self.shared_weights == True and self.base is not None:
mlp = nn.Sequential(
self.base.model[:-1],
nn.Linear(
self.base.model[-1].in_features, self.base.model[-1].out_features
),
)
return mlp
elif self.shared_weights == False:
layers_dim = (
[self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)]
)
mlp = nn.Sequential(
*(
sum(
[
[nn.Linear(idim, odim)]
+ ([activation] if n < len(layers_dim) - 2 else [])
for n, (idim, odim) in enumerate(
zip(layers_dim, layers_dim[1:])
)
],
[],
)
+ self.tail
)
)
return mlp
else:
raise ValueError(
"Base Model must be provided when shared_weights is set to True"
)


class Policy(ModelBase):
def __init__(self, config, env, device, float_precision, base=None):
super().__init__(config, env, device, float_precision, base)

self.instantiate()

def instantiate(self):
if self.type == "fixed":
Expand All @@ -105,12 +36,12 @@ def instantiate(self):
elif self.type == "uniform":
self.model = self.uniform_distribution
self.is_model = False
elif self.type == "mlp":
self.model = self.make_mlp(nn.LeakyReLU()).to(self.device)
self.is_model = True
else:
raise "Policy model type not defined"

def __call__(self, states):
return self.model(states)

def fixed_distribution(self, states):
"""
Returns the fixed distribution specified by the environment.
Expand Down
92 changes: 92 additions & 0 deletions gflownet/policy/cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import torch
from omegaconf import OmegaConf
from torch import nn

from gflownet.policy.base import Policy


class CNNPolicy(Policy):
def __init__(self, config, env, device, float_precision, base=None):
self.env = env
super().__init__(
config=config,
env=env,
device=device,
float_precision=float_precision,
base=base,
)

def make_cnn(self):
"""
Defines an CNN with no top layer activation
"""
if self.shared_weights and self.base is not None:
layers = list(self.base.model.children())[:-1]
last_layer = nn.Linear(
self.base.model[-1].in_features, self.base.model[-1].out_features
)

model = nn.Sequential(*layers, last_layer).to(self.device)
return model

current_channels = 1
conv_module = nn.Sequential()

if len(self.kernel_sizes) != self.n_layers:
raise ValueError(
f"Inconsistent dimensions kernel_sizes != n_layers, {len(self.kernel_sizes)} != {self.n_layers}"
)

for i in range(self.n_layers):
conv_module.add_module(
f"conv_{i}",
nn.Conv2d(
in_channels=current_channels,
out_channels=self.channels[i],
kernel_size=tuple(self.kernel_sizes[i]),
stride=tuple(self.strides[i]),
padding=0,
padding_mode="zeros", # Constant zero padding
),
)
conv_module.add_module(f"relu_{i}", nn.ReLU())
current_channels = self.channels[i]

dummy_input = torch.ones(
(1, 1, self.env.height, self.env.width)
) # (batch_size, channels, height, width)
try:
in_channels = conv_module(dummy_input).numel()
if in_channels >= 500_000: # TODO: this could better be handled
raise RuntimeWarning(
"Input channels for the dense layer are too big, this will increase number of parameters"
)
except RuntimeError as e:
raise RuntimeError(
"Failed during convolution operation. Ensure that the kernel sizes and strides are appropriate for the input dimensions."
) from e

model = nn.Sequential(
conv_module, nn.Flatten(), nn.Linear(in_channels, self.output_dim)
)
return model.to(self.device)

def parse_config(self, config):
engmubarak48 marked this conversation as resolved.
Show resolved Hide resolved
super().parse_config(config)
if config is None:
config = OmegaConf.create()
self.checkpoint = config.get("checkpoint", None)
self.shared_weights = config.get("shared_weights", False)
self.reload_ckpt = config.get("reload_ckpt", False)
self.n_layers = config.get("n_layers", 3)
self.channels = config.get("channels", [16] * self.n_layers)
self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers)
self.strides = config.get("strides", [(1, 1)] * self.n_layers)

def instantiate(self):
self.model = self.make_cnn()
self.is_model = True

def __call__(self, states):
states = states.unsqueeze(1) # (batch_size, channels, height, width)
return self.model(states)
Empty file added gflownet/policy/gnn.py
Empty file.
78 changes: 78 additions & 0 deletions gflownet/policy/mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from omegaconf import OmegaConf
from torch import nn

from gflownet.policy.base import Policy


class MLPPolicy(Policy):
def __init__(self, config, env, device, float_precision, base=None):
super().__init__(
config=config,
env=env,
device=device,
float_precision=float_precision,
base=base,
)

def make_mlp(self, activation):
"""
Defines an MLP with no top layer activation
If share_weight == True,
baseModel (the model with which weights are to be shared) must be provided
Args
----
layers_dim : list
Dimensionality of each layer
activation : Activation
Activation function
"""
if self.shared_weights == True and self.base is not None:
mlp = nn.Sequential(
self.base.model[:-1],
nn.Linear(
self.base.model[-1].in_features, self.base.model[-1].out_features
),
)
return mlp
elif self.shared_weights == False:
layers_dim = (
[self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)]
)
mlp = nn.Sequential(
*(
sum(
[
[nn.Linear(idim, odim)]
+ ([activation] if n < len(layers_dim) - 2 else [])
for n, (idim, odim) in enumerate(
zip(layers_dim, layers_dim[1:])
)
],
[],
)
+ self.tail
)
)
return mlp
else:
raise ValueError(
"Base Model must be provided when shared_weights is set to True"
)

def parse_config(self, config):
super().parse_config(config)
if config is None:
config = OmegaConf.create()
self.checkpoint = config.get("checkpoint", None)
self.shared_weights = config.get("shared_weights", False)
self.n_hid = config.get("n_hid", 128)
self.n_layers = config.get("n_layers", 2)
self.tail = config.get("tail", [])
self.reload_ckpt = config.get("reload_ckpt", False)

def instantiate(self):
self.model = self.make_mlp(nn.LeakyReLU()).to(self.device)
self.is_model = True

def __call__(self, states):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flattening could happen here or we could add nn.flatten to the model before linear layer. See CNN policy model

return self.model(states)
1 change: 0 additions & 1 deletion playground/botorch/mes_exact_deepKernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from math import floor

import gpytorch

# import tqdm
import torch
from botorch.test_functions import Hartmann
Expand Down
1 change: 0 additions & 1 deletion playground/botorch/mes_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import numpy as np
import torch

# from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.test_functions import Branin, Hartmann
Expand Down
Loading
Loading