
Sngp #28

Merged · 8 commits · Apr 25, 2023

Conversation

albertogaspar (Contributor)

Implementation of Spectral-normalized Neural Gaussian Process.

Pull request type

Please check the type of change your PR introduces:

  • Bugfix
  • Feature
  • Code style update (formatting, renaming)
  • Refactoring (no functional changes, no api changes)
  • Build related changes
  • Documentation content changes
  • Other (please describe):

What is the current behavior?

What is the new behavior?

New uncertainty quantification method that can be used with ResNet-like models.

A large part of the SNGP implementation (spectral_norm.py and random_features.py) has been adapted from the JAX (and TensorFlow) implementations in Edward2 and uncertainty-baselines.
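For readers new to the method, here is a minimal, illustrative sketch of the two ingredients SNGP combines: spectral normalization of weight matrices via power iteration, and a random-Fourier-feature approximation to an RBF-kernel GP output layer. This is not the PR's code (which, as noted above, is adapted from Edward2); all names here are hypothetical.

```python
import jax
import jax.numpy as jnp

def spectral_normalize(w: jnp.ndarray, key: jax.Array, n_iter: int = 1,
                       norm_bound: float = 0.95) -> jnp.ndarray:
    """Rescale w so its largest singular value is at most norm_bound."""
    u = jax.random.normal(key, (w.shape[0],))
    for _ in range(n_iter):  # power iteration estimates the top singular value
        v = w.T @ u
        v = v / jnp.linalg.norm(v)
        u = w @ v
        u = u / jnp.linalg.norm(u)
    sigma = u @ w @ v  # estimated spectral norm of w
    scale = jnp.where(sigma > norm_bound, norm_bound / sigma, 1.0)
    return scale * w

def random_fourier_features(x: jnp.ndarray, key: jax.Array,
                            num_features: int = 256) -> jnp.ndarray:
    """phi(x) such that phi(x) @ phi(y) approximates an RBF kernel k(x, y)."""
    k_w, k_b = jax.random.split(key)
    w = jax.random.normal(k_w, (x.shape[-1], num_features))
    b = jax.random.uniform(k_b, (num_features,), maxval=2 * jnp.pi)
    return jnp.sqrt(2.0 / num_features) * jnp.cos(x @ w + b)
```

In SNGP these work together: spectrally normalizing every layer keeps the feature extractor approximately distance preserving, and the GP head built on random features turns distances in feature space into predictive uncertainty.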

Other information

  • Callbacks have been moved to their own package under prob_model instead of being within prob_model.fit_config

@gianlucadetommaso (Contributor) left a comment

You never disappoint 🙂 Such great work!

I left several comments, mostly minor. My major concern is about simplicity of usage. Can we simplify the user experience? See the related comments.

docs/source/references/prob_model/callbacks.rst (3 resolved review threads)
examples/two_moons_classification_sngp.pct.py (resolved review thread)
    block_cls=self.block_cls,
    num_filters=self.num_filters,
    dtype=self.dtype,
    activation=self.activation,
Contributor:

In the MLP model above, there was a [:-1] here. Just raising attention in case of a bug.

albertogaspar (author):

Which MLP model above? This is a ResNet model.

Contributor:

Here. Isn't the deep feature extractor coming from an MLP?

pass


class ResNetSNGP(SNGPMixin, ResNet):
Contributor:

Shouldn't ResNets actually be the only models SNGPs are allowed to work with? In that case, why is the notebook example above doing this with an MLP?

Contributor:

On a second note, do you think ResNetSNGP is better than SNGPResNet? Same comment for the WideResNet models below.

albertogaspar (author), Apr 17, 2023:

In the notebook we used an MLP with residual connections.
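(For concreteness, an MLP with residual connections of the kind described here might look like the following Flax sketch. This is illustrative only, not the notebook's actual model; the residual connections are what keep the feature extractor compatible with SNGP's distance-preservation requirement.)

```python
import flax.linen as nn
import jax.numpy as jnp

class ResidualMLP(nn.Module):
    """MLP whose hidden blocks add a skip connection around each dense layer."""
    width: int = 128
    depth: int = 4

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = nn.Dense(self.width)(x)              # project inputs to hidden width
        for _ in range(self.depth):
            h = nn.relu(nn.Dense(self.width)(x))
            x = x + h                            # residual (skip) connection
        return x
```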

albertogaspar (author):

Maybe SNGPResNet is better; I'll change that.

# Default field value for kwargs, to be used for data class declaration.
default_kwarg_dict = lambda: dataclasses.field(default_factory=dict)

SUPPORTED_LIKELIHOOD = ("binary_logistic", "poisson", "gaussian")
Contributor:

As mentioned above, can we make this better?

self.covmat_layer = LaplaceRandomFeatureCovariance(
    hidden_features=self.hidden_features, **self.covmat_kwargs
)
# pylint:enable=invalid-name,not-a-mapping
Contributor:

I'd remove these #pylint comments.

fortuna/model/model_manager/classification.py (resolved review thread)
@thomaspinder (Collaborator) left a comment

Nice PR - very clean implementation :D As with @gianlucadetommaso, most of my comments are minor ones.

A general comment is to make sure that any type signatures of jnp.ndarray are replaced by jax.Array. You may also like to look into jaxtyping, where the array's shape can be specified dynamically. This will allow for nice type checking in the future.
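(For illustration, a hedged sketch of what such annotations could look like; the jaxtyping usage here is an assumption about style, not this PR's code:)

```python
import jax.numpy as jnp
from jaxtyping import Array, Float

def row_norms(x: Float[Array, "batch dim"]) -> Float[Array, "batch"]:
    # The string annotations document the shapes — a (batch, dim) float array
    # in, a (batch,) float array out — and can be enforced at runtime with a
    # checker such as beartype.
    return jnp.linalg.norm(x, axis=-1)
```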

docs/source/references/model/utils.rst (resolved review thread)
examples/two_moons_classification_sngp.pct.py (3 resolved review threads)
# Fortuna helps you convert data and data loaders into a data loader that Fortuna can digest.

# %%
from fortuna.data import DataLoader
Collaborator:

A general comment: all imports should be at the top of the file, to adhere to PEP 8.

albertogaspar (author):

Agreed, but is this true also for the example notebooks?

Contributor:

This is on me. I'm using imports in each cell of example notebooks to make it easier to understand where objects are called from. I guess it is debatable what's best practice in this case.


def __call__(
    self, x: Array, train: bool = False, **kwargs
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
Collaborator:

Typing on ndarray.

fortuna/model/utils/random_features.py (2 resolved review threads)
kernel_init: Callable[[PRNGKeyArray, Shape, Type], Array] = default_rbf_kernel_init
bias_init: Callable[[PRNGKeyArray, Shape, Type], Array] = default_rbf_bias_init
seed: int = 0
dtype: Type = jnp.float32
Collaborator:

Will those pose any problems when we compute Cholesky factors?

albertogaspar (author):

This is a great point. I don't know whether this will work as expected in half precision, for example. How do you suggest we proceed? Should we force the execution of this part in jnp.float32?
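(One possible shape of that workaround, sketched here with hypothetical names — this is not the code that was eventually merged: upcast to float32 before the factorization, then cast back to the working dtype.)

```python
import jax.numpy as jnp

def stable_cholesky(cov: jnp.ndarray, jitter: float = 1e-6) -> jnp.ndarray:
    """Cholesky factor computed in float32 regardless of the input dtype."""
    dtype = cov.dtype
    cov32 = cov.astype(jnp.float32)  # half precision often fails in Cholesky
    eye = jnp.eye(cov32.shape[-1], dtype=jnp.float32)
    chol = jnp.linalg.cholesky(cov32 + jitter * eye)  # jitter for stability
    return chol.astype(dtype)  # cast back to the caller's working dtype
```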

fortuna/model/utils/random_features.py (resolved review thread)
- Moved SNGP callback to fortuna.posterior.sngp
- Improved SNGP docs
- Moved ClassificationModelManagers to own py file
- Typos in two_moons_classification_sngp.pct.py
- When calling the train method for a :class:`~fortuna.prob_model.base.ProbModel` instance, add a list of callbacks containing the ones previously defined when initializing :class:`~fortuna.prob_model.fit_config.base.FitConfig`.

The following example outlines the usage of :class:`~fortuna.prob_model.callbacks.base.Callback`.
It assumes that the user already obtained an insatnce of :class:`~fortuna.prob_model.base.ProbModel`:
Contributor:

insatnce -> instance
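(For context, the docs snippet quoted above describes roughly the following pattern. The Callback import path comes from the snippet itself, but the hook name and the way callbacks are attached are assumptions for illustration, not Fortuna's confirmed API.)

```python
from fortuna.prob_model.callbacks.base import Callback

class CountEpochs(Callback):  # hypothetical callback for illustration
    """Counts training epochs; the hook name below is an assumption."""
    def __init__(self):
        self.n_epochs = 0

    def training_epoch_end(self, state):
        self.n_epochs += 1
        return state

# Per the quoted docs, a list of such callbacks is supplied when initializing
# fortuna.prob_model.fit_config.base.FitConfig, and then the ProbModel's
# train method is called as usual.
```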

- dryrun tests for sngp
- added possibility to do a sanity check of the posterior state in the posterior fit method
- for SNGP check that the deep feature extractor has spectral norm
@gianlucadetommaso marked this pull request as ready for review Apr 25, 2023, 13:47
@gianlucadetommaso merged commit 84db207 into awslabs:main on Apr 25, 2023
@albertogaspar deleted the sngp branch on May 18, 2023