[WIP] Make adding new Policy Models flexible #327
base: main
Changes from all commits
98e4d7e
1388756
7c8a517
4ecd865
d50aeea
1508e4b
9cdf18e
8af7f82
9474150
83a4904
e254cf2
d30c3f2
2ac9af5
7e02200
`.gitignore`:

```diff
@@ -15,3 +15,4 @@ playground/
 !docs/requirements-docs.txt
 .DS_Store
 docs/_build/
+logs
```
New file, the CNN policy config (+16 lines):

```yaml
_target_: gflownet.policy.cnn.CNNPolicy

shared: null

forward:
  n_layers: 2
  channels: [16, 32]
  kernel_sizes: [[3, 3], [2, 2]] # Each pair represents (height, width)
  strides: [[1, 1], [1, 1]] # Each pair represents (vertical_stride, horizontal_stride)
  checkpoint: null
  reload_ckpt: False

backward:
  shared_weights: True
  checkpoint: null
  reload_ckpt: False
```
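For orientation, here is a minimal sketch of how these fields reach the policy, assuming the OmegaConf flow implied by `_target_` and the `parse_config` methods further down in this diff; the snippet builds the `forward` sub-config by hand and is illustrative only, not code from the PR:

```python
from omegaconf import OmegaConf

# Illustrative: recreate the `forward` sub-config from the YAML above and
# read it back the way CNNPolicy.parse_config does, with .get() defaults.
forward_cfg = OmegaConf.create(
    {
        "n_layers": 2,
        "channels": [16, 32],
        "kernel_sizes": [[3, 3], [2, 2]],
        "strides": [[1, 1], [1, 1]],
        "checkpoint": None,
        "reload_ckpt": False,
    }
)
n_layers = forward_cfg.get("n_layers", 3)                 # 2 (from the config)
channels = forward_cfg.get("channels", [16] * n_layers)   # [16, 32]
```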
Tetris environment:

```diff
@@ -75,6 +75,7 @@ def __init__(
         height: int = 20,
         pieces: List = ["I", "J", "L", "O", "S", "T", "Z"],
         rotations: List = [0, 90, 180, 270],
+        flatten: bool = True,
         allow_redundant_rotations: bool = False,
         allow_eos_before_full: bool = False,
         **kwargs,
```

Review comment on `flatten: bool = True`:

> If we move the flattening from the environment to the policy, then we don't need this.

```diff
@@ -87,6 +88,7 @@ def __init__(
         self.height = height
         self.pieces = pieces
         self.rotations = rotations
+        self.flatten = flatten
         self.allow_redundant_rotations = allow_redundant_rotations
         self.allow_eos_before_full = allow_eos_before_full
         self.max_pieces_per_type = 100
```

```diff
@@ -307,7 +309,9 @@ def states2policy(
             A tensor containing all the states in the batch.
         """
         states = tint(states, device=self.device, int_type=self.int)
-        return self.states2proxy(states).flatten(start_dim=1).to(self.float)
+        if self.flatten:
+            return self.states2proxy(states).flatten(start_dim=1).to(self.float)
+        return self.states2proxy(states).to(self.float)

     def state2readable(self, state: Optional[TensorType["height", "width"]] = None):
         """
```

Review comment on lines +312 to +314:

> @alexhernandezgarcia This is a temporary solution to make the CNN policy work on the Tetris env, but normally the flattening should happen inside the model, not in the environment (see my other comments). If you are okay with that, then I can update.
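For reference, a sketch of what the method would reduce to under the reviewer's proposal (flattening owned by the policy models, the `flatten` flag removed); this illustrates the suggestion and is not code from the diff:

```python
# Sketch: states2policy if flattening moves into the policy models.
# `tint` and `states2proxy` are the same helpers used above.
def states2policy(self, states):
    states = tint(states, device=self.device, int_type=self.int)
    # Always return the (batch, height, width) proxy representation;
    # MLP/CNN policies flatten (or not) as their architectures require.
    return self.states2proxy(states).to(self.float)
```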
```diff
@@ -581,7 +585,7 @@ def _plot_board(board, ax: Axes, cellsize: int = 20, linewidth: int = 2):
         linewidth : int
             The width of the separation between cells, in pixels.
         """
-        board = board.clone().numpy()
+        board = board.clone().cpu().numpy()
         height = board.shape[0] * cellsize
         width = board.shape[1] * cellsize
         board_img = 128 * np.ones(
```
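The `.cpu()` addition matters because `torch.Tensor.numpy()` raises a `TypeError` on CUDA tensors; a minimal reproduction, assuming a CUDA device is available:

```python
import torch

board = torch.zeros(20, 10, device="cuda")
# board.clone().numpy()       # TypeError: can't convert cuda:0 device type
#                             # tensor to numpy; use Tensor.cpu() first
board_np = board.clone().cpu().numpy()  # OK: copies to host memory first
```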
New file `gflownet/policy/cnn.py` (+92 lines):

```python
import torch
from omegaconf import OmegaConf
from torch import nn

from gflownet.policy.base import Policy


class CNNPolicy(Policy):
    def __init__(self, config, env, device, float_precision, base=None):
        self.env = env
        super().__init__(
            config=config,
            env=env,
            device=device,
            float_precision=float_precision,
            base=base,
        )

    def make_cnn(self):
        """
        Defines a CNN with no top-layer activation.
        """
        if self.shared_weights and self.base is not None:
            # Share all layers of the base model except the last one, and
            # append a fresh linear layer with the same dimensions.
            layers = list(self.base.model.children())[:-1]
            last_layer = nn.Linear(
                self.base.model[-1].in_features, self.base.model[-1].out_features
            )
            model = nn.Sequential(*layers, last_layer).to(self.device)
            return model

        current_channels = 1
        conv_module = nn.Sequential()

        if len(self.kernel_sizes) != self.n_layers:
            raise ValueError(
                f"Inconsistent dimensions kernel_sizes != n_layers, "
                f"{len(self.kernel_sizes)} != {self.n_layers}"
            )

        for i in range(self.n_layers):
            conv_module.add_module(
                f"conv_{i}",
                nn.Conv2d(
                    in_channels=current_channels,
                    out_channels=self.channels[i],
                    kernel_size=tuple(self.kernel_sizes[i]),
                    stride=tuple(self.strides[i]),
                    padding=0,
                    padding_mode="zeros",  # Constant zero padding
                ),
            )
            conv_module.add_module(f"relu_{i}", nn.ReLU())
            current_channels = self.channels[i]

        # Infer the flattened feature size of the conv stack with a dummy
        # forward pass, so the final linear layer can be sized accordingly.
        dummy_input = torch.ones(
            (1, 1, self.env.height, self.env.width)
        )  # (batch_size, channels, height, width)
        try:
            in_channels = conv_module(dummy_input).numel()
            if in_channels >= 500_000:  # TODO: this could better be handled
                raise RuntimeWarning(
                    "Input channels for the dense layer are too big, this will "
                    "increase the number of parameters"
                )
        except RuntimeError as e:
            raise RuntimeError(
                "Failed during convolution operation. Ensure that the kernel "
                "sizes and strides are appropriate for the input dimensions."
            ) from e

        model = nn.Sequential(
            conv_module, nn.Flatten(), nn.Linear(in_channels, self.output_dim)
        )
        return model.to(self.device)

    def parse_config(self, config):
        super().parse_config(config)
        if config is None:
            config = OmegaConf.create()
        self.checkpoint = config.get("checkpoint", None)
        self.shared_weights = config.get("shared_weights", False)
        self.reload_ckpt = config.get("reload_ckpt", False)
        self.n_layers = config.get("n_layers", 3)
        self.channels = config.get("channels", [16] * self.n_layers)
        self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers)
        self.strides = config.get("strides", [(1, 1)] * self.n_layers)

    def instantiate(self):
        self.model = self.make_cnn()
        self.is_model = True

    def __call__(self, states):
        states = states.unsqueeze(1)  # (batch_size, channels, height, width)
        return self.model(states)
```
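To make the shape bookkeeping in `make_cnn` concrete, here is a self-contained sketch in plain PyTorch that builds the same two-layer stack as the config above and sizes the dense layer with a dummy forward pass; the board dimensions (20 by 10) and `output_dim = 10` are placeholder values, not taken from the diff:

```python
import torch
from torch import nn

channels = [16, 32]
kernel_sizes = [(3, 3), (2, 2)]
strides = [(1, 1), (1, 1)]
height, width, output_dim = 20, 10, 10  # placeholders

conv = nn.Sequential()
in_ch = 1
for i, out_ch in enumerate(channels):
    conv.add_module(
        f"conv_{i}", nn.Conv2d(in_ch, out_ch, kernel_sizes[i], strides[i])
    )
    conv.add_module(f"relu_{i}", nn.ReLU())
    in_ch = out_ch

# Dummy pass infers the flattened feature count: here
# (1, 1, 20, 10) -> (1, 16, 18, 8) -> (1, 32, 17, 7) -> 3808 features.
n_features = conv(torch.ones(1, 1, height, width)).numel()
model = nn.Sequential(conv, nn.Flatten(), nn.Linear(n_features, output_dim))

states = torch.ones(8, height, width)  # a batch of board states
logits = model(states.unsqueeze(1))    # (8, 1, 20, 10) -> (8, 10)
```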
New file, the MLP policy (+78 lines):

```python
from omegaconf import OmegaConf
from torch import nn

from gflownet.policy.base import Policy


class MLPPolicy(Policy):
    def __init__(self, config, env, device, float_precision, base=None):
        super().__init__(
            config=config,
            env=env,
            device=device,
            float_precision=float_precision,
            base=base,
        )

    def make_mlp(self, activation):
        """
        Defines an MLP with no top-layer activation.

        If shared_weights is True, the base model (the model with which
        weights are to be shared) must be provided.

        Args
        ----
        activation : Activation
            Activation function
        """
        if self.shared_weights and self.base is not None:
            # Share every layer of the base model except the last one, and
            # append a fresh linear layer with the same dimensions.
            mlp = nn.Sequential(
                self.base.model[:-1],
                nn.Linear(
                    self.base.model[-1].in_features, self.base.model[-1].out_features
                ),
            )
            return mlp
        elif not self.shared_weights:
            layers_dim = (
                [self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)]
            )
            # Interleave linear layers with the activation (skipping it after
            # the output layer), then append any extra tail modules.
            mlp = nn.Sequential(
                *(
                    sum(
                        [
                            [nn.Linear(idim, odim)]
                            + ([activation] if n < len(layers_dim) - 2 else [])
                            for n, (idim, odim) in enumerate(
                                zip(layers_dim, layers_dim[1:])
                            )
                        ],
                        [],
                    )
                    + self.tail
                )
            )
            return mlp
        else:
            raise ValueError(
                "Base model must be provided when shared_weights is set to True"
            )

    def parse_config(self, config):
        super().parse_config(config)
        if config is None:
            config = OmegaConf.create()
        self.checkpoint = config.get("checkpoint", None)
        self.shared_weights = config.get("shared_weights", False)
        self.n_hid = config.get("n_hid", 128)
        self.n_layers = config.get("n_layers", 2)
        self.tail = config.get("tail", [])
        self.reload_ckpt = config.get("reload_ckpt", False)

    def instantiate(self):
        self.model = self.make_mlp(nn.LeakyReLU()).to(self.device)
        self.is_model = True

    def __call__(self, states):
        return self.model(states)
```

Review comment on `__call__`:

> Flattening could happen here, or we could add nn.Flatten to the model before the linear layer. See the CNN policy model.
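A sketch of the two options that comment mentions, hedged since neither is implemented in this diff; the dimensions are placeholders with `state_dim` assumed to equal `H * W` of the board:

```python
import torch
from torch import nn

# Option A (per the comment): flatten inside __call__, model unchanged.
def call_with_flatten(model, states):
    # states: (batch, H, W) -> (batch, H * W)
    return model(states.flatten(start_dim=1))

# Option B: bake nn.Flatten into the model itself, as CNNPolicy does.
state_dim, n_hid, output_dim = 200, 128, 10  # placeholders, state_dim = H * W
mlp = nn.Sequential(
    nn.Flatten(),                 # (batch, H, W) -> (batch, H * W)
    nn.Linear(state_dim, n_hid),
    nn.LeakyReLU(),
    nn.Linear(n_hid, output_dim),
)

states = torch.ones(8, 20, 10)              # (batch, H, W)
out_a = call_with_flatten(mlp[1:], states)  # Option A path -> (8, 10)
out_b = mlp(states)                         # Option B path -> (8, 10)
```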
Review comment:

> If we move the flattening from the environment to the policy, then we don't need this, and we can simply remove it.