Decouple predictive from the Laplace class? #8

Open
wiseodd opened this issue May 17, 2021 · 0 comments
Labels
question Further information is requested

Comments

@wiseodd
Collaborator

wiseodd commented May 17, 2021

Useful for: Users who want to implement custom predictive approximations.

Issue: Currently, the predictive approximation is tightly coupled with the Laplace class. A user who wants to implement a new predictive approximation has to dig deep into this class, which can be confusing and risks breaking something.

Proposal:

  • Two-step predictive interface (function-output predictive and link predictive):
```python
import math

import torch


class FunctionPredictive:

    def __init__(self, ...):
        ...

    def __call__(self, x):
        ''' Return 2 arrays: the means and variances of f(x) '''
        raise NotImplementedError()


class LinearizedPredictive(FunctionPredictive):

    def __init__(self, laplace_net, ...):
        self.laplace_net = laplace_net
        ...

    def __call__(self, x):
        # Linearize the network around the MAP estimate: f(x) ~ N(f_MAP(x), J^T Sigma J),
        # assuming J is laid out as (num_params, num_outputs)
        J = compute_jacobian(self.laplace_net, x)
        return self.laplace_net.map_prediction(x), J.T @ self.laplace_net.covmat @ J


class LinkPredictive:

    def __init__(self, ...):
        ...

    def __call__(self, f_mean, f_var):
        ''' Map the function-space mean and variance to a predictive over outputs '''
        raise NotImplementedError()


class ProbitPredictive(LinkPredictive):

    def __init__(self, ...):
        ...

    def __call__(self, f_mean, f_var):
        # Probit approximation to E[sigmoid(f)] with f ~ N(f_mean, f_var)
        return torch.sigmoid(f_mean / torch.sqrt(1 + math.pi / 8 * f_var))
```
  • Usage:
```python
linearized_pred = LinearizedPredictive()
probit_pred = ProbitPredictive()  # Set it to `None` for regression
laplace_net = Laplace(..., function_predictive=linearized_pred, link_predictive=probit_pred)
laplace_net.fit(train_loader)
laplace_net(x)  # Predict using the specified predictives
```
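To make the two-step composition concrete, here is a minimal, runnable sketch in PyTorch under simplifying assumptions: the model is a toy linear model f(x) = x @ w with an assumed Gaussian posterior N(w_map, cov), so its Jacobian with respect to w is just x and the linearized predictive is exact (no `compute_jacobian` helper is needed). `LinearModelPredictive`, `MonteCarloLinkPredictive` and `predict` are hypothetical names for illustration only; they are not part of the proposal above or of any library. The Monte Carlo link stands in for the kind of custom predictive a user might plug in without touching the function predictive.

```python
import math
from abc import ABC, abstractmethod

import torch


class FunctionPredictive(ABC):
    """Step 1: map inputs x to the mean and variance of the latent function f(x)."""

    @abstractmethod
    def __call__(self, x):
        raise NotImplementedError


class LinkPredictive(ABC):
    """Step 2: map (f_mean, f_var) to a predictive over outputs."""

    @abstractmethod
    def __call__(self, f_mean, f_var):
        raise NotImplementedError


class LinearModelPredictive(FunctionPredictive):
    """Hypothetical predictive for f(x) = x @ w with posterior N(w_map, cov).

    The Jacobian of f w.r.t. w is x itself, so linearization is exact here.
    """

    def __init__(self, w_map, cov):
        self.w_map = w_map
        self.cov = cov

    def __call__(self, x):
        f_mean = x @ self.w_map                   # shape (N,)
        f_var = ((x @ self.cov) * x).sum(dim=-1)  # diagonal of x @ cov @ x.T, shape (N,)
        return f_mean, f_var


class ProbitPredictive(LinkPredictive):
    """Probit approximation to E[sigmoid(f)] for f ~ N(f_mean, f_var)."""

    def __call__(self, f_mean, f_var):
        return torch.sigmoid(f_mean / torch.sqrt(1 + math.pi / 8 * f_var))


class MonteCarloLinkPredictive(LinkPredictive):
    """Hypothetical custom link: estimate E[sigmoid(f)] by sampling f."""

    def __init__(self, n_samples=100_000, seed=0):
        self.n_samples = n_samples
        self.generator = torch.Generator().manual_seed(seed)

    def __call__(self, f_mean, f_var):
        eps = torch.randn(self.n_samples, *f_mean.shape,
                          generator=self.generator, dtype=f_mean.dtype)
        f_samples = f_mean + f_var.sqrt() * eps   # shape (n_samples, N)
        return torch.sigmoid(f_samples).mean(dim=0)


def predict(x, function_predictive, link_predictive=None):
    """Compose the two steps; without a link (regression) return (f_mean, f_var)."""
    f_mean, f_var = function_predictive(x)
    if link_predictive is None:
        return f_mean, f_var
    return link_predictive(f_mean, f_var)


# Toy posterior and inputs (made-up numbers for illustration only).
w_map = torch.tensor([1.0, -2.0], dtype=torch.float64)
cov = torch.tensor([[0.5, 0.1], [0.1, 0.3]], dtype=torch.float64)
x = torch.tensor([[1.0, 0.0], [0.5, 1.0], [2.0, 2.0]], dtype=torch.float64)

function_pred = LinearModelPredictive(w_map, cov)
f_mean, f_var = predict(x, function_pred)  # regression: f_mean = [1.0, -1.5, -2.0], f_var = [0.5, 0.525, 4.0]
p_probit = predict(x, function_pred, ProbitPredictive())
p_mc = predict(x, function_pred, MonteCarloLinkPredictive())
print(p_probit, p_mc)
```

Because both link predictives consume the same `(f_mean, f_var)` pair, switching from the probit approximation to Monte Carlo averaging only changes the last argument to `predict`. The two results agree only approximately, since the probit formula replaces the sigmoid with a rescaled Gaussian CDF.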
@aleximmer added the question (Further information is requested) label on Nov 5, 2021