[WIP] Make adding new Policy Models flexible #327

Open · wants to merge 11 commits into main
60 changes: 15 additions & 45 deletions gflownet/policy/base.py
@@ -46,51 +46,6 @@ def instantiate(self):
    def __call__(self, states):
        return self.model(states)

    def make_mlp(self, activation):
        """
        Defines an MLP with no top layer activation
        If share_weight == True,
        baseModel (the model with which weights are to be shared) must be provided
        Args
        ----
        layers_dim : list
            Dimensionality of each layer
        activation : Activation
            Activation function
        """
        if self.shared_weights == True and self.base is not None:
            mlp = nn.Sequential(
                self.base.model[:-1],
                nn.Linear(
                    self.base.model[-1].in_features, self.base.model[-1].out_features
                ),
            )
            return mlp
        elif self.shared_weights == False:
            layers_dim = (
                [self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)]
            )
            mlp = nn.Sequential(
                *(
                    sum(
                        [
                            [nn.Linear(idim, odim)]
                            + ([activation] if n < len(layers_dim) - 2 else [])
                            for n, (idim, odim) in enumerate(
                                zip(layers_dim, layers_dim[1:])
                            )
                        ],
                        [],
                    )
                    + self.tail
                )
            )
            return mlp
        else:
            raise ValueError(
                "Base Model must be provided when shared_weights is set to True"
            )


class Policy(ModelBase):
    def __init__(self, config, env, device, float_precision, base=None):
@@ -111,6 +66,21 @@ def instantiate(self):
        else:
            raise "Policy model type not defined"

    def instantiate(self):
        if self.type == "fixed":
            self.model = self.fixed_distribution
            self.is_model = False
        elif self.type == "uniform":
            self.model = self.uniform_distribution
            self.is_model = False
        elif self.type == "mlp":
            # Deferred import: gflownet.policy.mlp imports ModelBase from
            # this module, so a top-level import would be circular.
            from gflownet.policy.mlp import MLPPolicy

            mlp_policy = MLPPolicy(
                self.config, self.env, self.device, self.float_precision, self.base
            )
            self.model = mlp_policy.model
            self.is_model = mlp_policy.is_model
        else:
            raise ValueError("Policy model type not defined")

    def fixed_distribution(self, states):
        """
        Returns the fixed distribution specified by the environment.
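With this dispatch in place, supporting another model type is one extra branch in `Policy.instantiate`. A hedged sketch of what that could look like for the placeholder files added below; the `"cnn"` type name and `CNNPolicy` class are assumptions, not part of this diff:

```python
        # Hypothetical fragment for Policy.instantiate, mirroring the
        # "mlp" branch above; CNNPolicy does not exist yet
        # (gflownet/policy/cnn.py is an empty placeholder in this PR).
        elif self.type == "cnn":
            from gflownet.policy.cnn import CNNPolicy

            cnn_policy = CNNPolicy(
                self.config, self.env, self.device, self.float_precision, self.base
            )
            self.model = cnn_policy.model
            self.is_model = cnn_policy.is_model
```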
Empty file added gflownet/policy/cnn.py
Empty file.
Empty file added gflownet/policy/gnn.py
Empty file.
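For illustration, a minimal sketch of what `gflownet/policy/cnn.py` might contain under this design. The architecture here (a small Conv1d trunk) and the assumption that states arrive as flat `(batch, state_dim)` tensors are illustrative; the only real contract, judging from `MLPPolicy`, is that `instantiate()` sets `self.model` and `self.is_model`:

```python
from torch import nn

from gflownet.policy.base import ModelBase


class CNNPolicy(ModelBase):
    def __init__(self, config, env, device, float_precision, base=None):
        super().__init__(config, env, device, float_precision, base)
        self.instantiate()

    def instantiate(self):
        # Illustrative architecture: reshape flat states to (B, 1, D),
        # run a small Conv1d trunk, and project to the policy output dim.
        self.model = nn.Sequential(
            nn.Unflatten(1, (1, self.state_dim)),
            nn.Conv1d(1, 16, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(16 * self.state_dim, self.output_dim),
        ).to(self.device)
        self.is_model = True
```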
60 changes: 60 additions & 0 deletions gflownet/policy/mlp.py
@@ -0,0 +1,60 @@
# gflownet/policy/mlp.py

from torch import nn

from gflownet.policy.base import ModelBase


class MLPPolicy(ModelBase):
    def __init__(self, config, env, device, float_precision, base=None):
        super().__init__(config, env, device, float_precision, base)
        self.instantiate()

    def instantiate(self):
        # LeakyReLU between hidden layers; the top layer stays linear.
        self.model = self.make_mlp(nn.LeakyReLU()).to(self.device)
        self.is_model = True

    def make_mlp(self, activation):
        """
        Defines an MLP with no top-layer activation.

        If self.shared_weights is True, a base model (the model whose
        weights are to be shared) must be provided via self.base.

        Args
        ----
        activation : nn.Module
            Activation function to apply between hidden layers.
        """
        if self.shared_weights and self.base is not None:
            # Reuse every layer of the base model except its top layer,
            # then stack a fresh Linear layer with matching dimensions.
            mlp = nn.Sequential(
                self.base.model[:-1],
                nn.Linear(
                    self.base.model[-1].in_features, self.base.model[-1].out_features
                ),
            )
            return mlp
        elif not self.shared_weights:
            layers_dim = (
                [self.state_dim] + [self.n_hid] * self.n_layers + [self.output_dim]
            )
            # Interleave Linear layers with the activation, skipping the
            # activation after the final (output) layer.
            mlp = nn.Sequential(
                *(
                    sum(
                        [
                            [nn.Linear(idim, odim)]
                            + ([activation] if n < len(layers_dim) - 2 else [])
                            for n, (idim, odim) in enumerate(
                                zip(layers_dim, layers_dim[1:])
                            )
                        ],
                        [],
                    )
                    + self.tail
                )
            )
            return mlp
        else:
            raise ValueError(
                "Base Model must be provided when shared_weights is set to True"
            )
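The nested `sum(...)` construction and the shared-weights branch are dense; a standalone sketch of the weight-sharing wiring, with made-up dimensions, may help reviewers confirm the intent. The trunk modules are reused by reference, so only the top layer is independent:

```python
import torch
from torch import nn

# Stand-in for self.base.model: a 4 -> 32 -> 32 -> 6 MLP.
base_model = nn.Sequential(
    nn.Linear(4, 32), nn.LeakyReLU(),
    nn.Linear(32, 32), nn.LeakyReLU(),
    nn.Linear(32, 6),
)

# Same construction as the shared_weights branch above: reuse every
# layer except the last, then add a fresh Linear of matching shape.
shared = nn.Sequential(
    base_model[:-1],
    nn.Linear(base_model[-1].in_features, base_model[-1].out_features),
)

x = torch.randn(8, 4)
assert shared(x).shape == (8, 6)
# The trunk parameters are shared by reference with the base model:
assert shared[0][0].weight is base_model[0].weight
```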