Restored MLP and created a separate class for MLP with residual connections
Alberto Gasparin committed Apr 16, 2023
1 parent cdf4856 commit 3c1488a
Showing 2 changed files with 72 additions and 12 deletions.
examples/two_moons_classification_sngp.pct.py (10 changes: 5 additions & 5 deletions)
@@ -73,11 +73,11 @@ def plot_uncertainty(prob_model, test_data_loader, grid_size=100):
# In this tutorial we will use a deep residual network.

# %%
-from fortuna.model.mlp import MLP
+from fortuna.model.mlp import DeepResidualNet
import flax.linen as nn

output_dim = 2
-model = MLP(
+model = DeepResidualNet(
    output_dim=output_dim,
    activations=(nn.relu, nn.relu, nn.relu, nn.relu, nn.relu, nn.relu),
    widths=(128,128,128,128,128,128),
@@ -131,10 +131,10 @@ def plot_uncertainty(prob_model, test_data_loader, grid_size=100):
# and `WithSpectralNorm`:

# %%
-from fortuna.model.mlp import MLPDeepFeatureExtractorSubNet
+from fortuna.model.mlp import DeepResidualFeatureExtractorSubNet
from fortuna.model.utils.spectral_norm import WithSpectralNorm

-class SNGPDeepFeatureExtractorSubNet(WithSpectralNorm, MLPDeepFeatureExtractorSubNet):
+class SNGPDeepFeatureExtractorSubNet(WithSpectralNorm, DeepResidualFeatureExtractorSubNet):
    pass

# %% [markdown]
@@ -147,7 +147,7 @@ class SNGPDeepFeatureExtractorSubNet(WithSpectralNorm, MLPDeepFeatureExtractorSubNet):
from fortuna.model.utils.random_features import RandomFeatureGaussianProcess

from fortuna.model.sngp import SNGPMixin
-class SNGPModel(SNGPMixin, MLP):
+class SNGPModel(SNGPMixin, DeepResidualNet):
    def setup(self):
        if len(self.widths) != len(self.activations):
            raise Exception(
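The renames above rely on plain Python mixin composition: listing `WithSpectralNorm` before the feature-extractor base gives the combined class a `spectral_norm` attribute, which the inherited forward pass can detect at call time (see the `hasattr(self, 'spectral_norm')` check added to fortuna/model/mlp.py below). A minimal, self-contained sketch of that pattern, using hypothetical class names rather than Fortuna's actual classes:

class SpectralNormMixin:
    # Stand-in for WithSpectralNorm: the mixin only has to supply the attribute.
    def spectral_norm(self, layer_name):
        return "spectral_norm(" + layer_name + ")"


class FeatureExtractor:
    def describe_dense(self):
        # Use the wrapper only if a mixin provided `spectral_norm`.
        if hasattr(self, "spectral_norm"):
            return self.spectral_norm("dense")
        return "dense"


class SNFeatureExtractor(SpectralNormMixin, FeatureExtractor):
    pass


print(FeatureExtractor().describe_dense())    # -> dense
print(SNFeatureExtractor().describe_dense())  # -> spectral_norm(dense)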
fortuna/model/mlp.py (74 changes: 67 additions & 7 deletions)
@@ -10,13 +10,6 @@


class MLP(nn.Module):
-    output_dim: int
-    widths: Optional[Tuple[int]] = (30, 30)
-    activations: Optional[Tuple[Callable[[Array], Array]]] = (nn.relu, nn.relu)
-    dropout: ModuleDef = nn.Dropout
-    dropout_rate: float = 0.1
-    dense: ModuleDef = nn.Dense
-
    """
    A multi-layer perceptron (MLP).
@@ -35,6 +28,12 @@ class MLP(nn.Module):
    dense: ModuleDef
        Dense module.
    """
+    output_dim: int
+    widths: Optional[Tuple[int]] = (30, 30)
+    activations: Optional[Tuple[Callable[[Array], Array]]] = (nn.relu, nn.relu)
+    dropout: ModuleDef = nn.Dropout
+    dropout_rate: float = 0.1
+    dense: ModuleDef = nn.Dense

    def setup(self):
        if len(self.widths) != len(self.activations):
@@ -60,6 +59,29 @@ def __call__(self, x: Array, train: bool = False, **kwargs) -> jnp.ndarray:
        return x


+class DeepResidualNet(MLP):
+    """
+    A multi-layer perceptron with residual connections
+    """
+    def setup(self):
+        if len(self.widths) != len(self.activations):
+            raise Exception(
+                "`widths` and `activations` must have the same number of elements."
+            )
+        self.dfe_subnet = DeepResidualFeatureExtractorSubNet(
+            dense=self.dense,
+            widths=self.widths,
+            activations=self.activations[:-1],
+            dropout=self.dropout,
+            dropout_rate=self.dropout_rate,
+        )
+        self.output_subnet = MLPOutputSubNet(
+            dense=self.dense,
+            activation=self.activations[-1],
+            output_dim=self.output_dim,
+        )
+
+
class MLPDeepFeatureExtractorSubNet(nn.Module):
    widths: Tuple[int]
    activations: Tuple[Callable[[Array], Array]]
@@ -84,6 +106,44 @@ class MLPDeepFeatureExtractorSubNet(nn.Module):
        Dropout rate.
    """

+    @nn.compact
+    def __call__(self, x: Array, train: bool = False, **kwargs) -> jnp.ndarray:
+        """
+        Forward pass.
+        Parameters
+        ----------
+        x: Array
+            Inputs.
+        train: bool
+            Whether it is training or inference.
+        Returns
+        -------
+        jnp.ndarray
+            Output of the hidden layers.
+        """
+        if hasattr(self, 'spectral_norm'):
+            dense = self.spectral_norm(self.dense, train=train)
+        else:
+            dense = self.dense
+        dropout = self.dropout(self.dropout_rate)
+        n_activations = len(self.activations)
+
+        def update(i: int, x):
+            x = dense(self.widths[i], name="hidden" + str(i + 1))(x)
+            if i < n_activations:
+                x = self.activations[i](x)
+            x = dropout(x, deterministic=not train)
+            return x
+
+        x = x.reshape(x.shape[0], -1)
+        for i in range(0, len(self.widths)):
+            x = update(i, x)
+        return x
+
+
+class DeepResidualFeatureExtractorSubNet(MLPDeepFeatureExtractorSubNet):
+    @nn.compact
+    def __call__(self, x: Array, train: bool = False, **kwargs) -> jnp.ndarray:
+        """

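For readers skimming the commit, here is a minimal, self-contained flax sketch of the general idea behind an MLP feature extractor with residual (skip) connections, followed by the usual init/apply usage. It is an illustration only: the class and layer names (ToyResidualFeatureExtractor, input_projection) are hypothetical, the body of DeepResidualFeatureExtractorSubNet.__call__ is truncated in the diff above, and Fortuna's actual implementation may differ.

from typing import Callable, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp


class ToyResidualFeatureExtractor(nn.Module):
    # Uniform widths keep every residual sum shape-compatible.
    widths: Tuple[int, ...] = (128, 128, 128)
    activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray:
        x = x.reshape(x.shape[0], -1)
        # Project the input once so it matches the hidden width.
        x = nn.Dense(self.widths[0], name="input_projection")(x)
        for i, width in enumerate(self.widths):
            h = nn.Dense(width, name="hidden" + str(i + 1))(x)
            h = self.activation(h)
            h = nn.Dropout(self.dropout_rate)(h, deterministic=not train)
            x = x + h  # residual (skip) connection
        return x


# Usage on toy 2-D inputs (e.g. two-moons-like data).
inputs = jnp.ones((4, 2))
module = ToyResidualFeatureExtractor()
params = module.init(jax.random.PRNGKey(0), inputs)  # train=False: dropout is off
features = module.apply(params, inputs)  # shape (4, 128)
# At training time, dropout needs its own RNG stream:
features_train = module.apply(
    params, inputs, train=True, rngs={"dropout": jax.random.PRNGKey(1)}
)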