
Suggestions for documentation and features #94

Open
John-Boik opened this issue Apr 2, 2023 · 11 comments
Assignees
Labels
documentation Improvements or additions to documentation enhancement New feature or request good first issue Good for newcomers

Comments

@John-Boik

I have done a (reasonably) deep dive into learning RxInfer over the past several weeks and would like to offer some suggestions regarding documentation and features. I write this both from my recent experience and from my guess as to how some other practitioners might wish to use (and learn) RxInfer, and where they might get stuck. I imagine that you (the RxInfer team) want the package to be used by a wide audience. Some material below is in reference to my post on discourse and the associated request to open an issue here. That post is about documentation for the score function and @average_energy macro.

First, some kudos are in order. RxInfer is quite impressive and it represents a large and important step forward from ForneyLab (which I have slight acquaintance with). I understand that RxInfer is a work in progress and that you likely have a variety of short- and long-term plans for documentation and features. So if I don't offer anything new here, please take this post as confirmation of your current direction. This post is lengthy, and if I'm guided to split it into certain portions for separate issues, I'm happy to do so. And I'm happy to contribute in some way towards the suggestions below, but I'm not sure what would be most helpful.

Big Picture

I suspect that you might be aiming to eventually implement a large set of nodes, associated rules, and approximation methods so that users can easily build models by combining building blocks via short snippets of code. If so, this would of course be useful. But in the meantime, and perhaps always, some users (like myself) will be interested in relying heavily on custom nodes. Interesting real-world problems might tend to require custom nodes, and custom nodes provide a means to tinker and learn about a given model, about RxInfer, and about new methods and approaches. You already allow for the creation of custom nodes, but I'm suggesting that creating custom nodes might actually be the bread and butter for some users. In my learning curve, almost all of my time was spent getting custom nodes and associated rules to work correctly.

I understand that part of the upcoming documentation will include more details on creating custom nodes. I'd like to suggest what might be additions to what you have planned. While the examples you now offer (e.g., mountain car problem) are useful, and more would be helpful, what would have been most useful to me in learning RxInfer is a large set of simple, self-contained custom nodes, complete with message passing rules, marginal rules, and score/free energy functions. By simple, I mean a series of simple linear and nonlinear operators, such as [+, -, *, /, log(x), x^2] combined with a series of approximation methods, such as: [no approximation, linearization, CVI, importance, RTS]. Some nodes would involve single operators and others (composite nodes) would involve multiple operators, for example: [+ and * and log and CVI]. Another good example would be a node that returns a Boolean variable for p(x>z) < eta, where x is a random variable and z and eta are scalars. Some custom nodes would be deterministic, and some stochastic. The code for each node would exist in a single file. It would not need its own notebook page or lengthy introduction (but liberal code comments are always helpful). Each node would probably need a simple (bare minimum) model and inference implementation, so that one can check that the node actually works and returns a free energy value, etc. On that note, some implementations could use the inference() function, while others use the update!() function, just for variety.

Obviously, RxInfer already includes code for nodes like + and matrix multiplication. For these and others, a user can figure things out by digging through your code base. My suggestion is to have, for each custom node example, one self-contained file that includes a complete implementation, cut down to the bare minimum, simplest code, including all rules, score methods (for free energy calculation), and so on. To make things really simple, and eliminate most multiple dispatch (as a user might do when prototyping for a specific problem), each node can be tailored for a specific set of message types (Gaussian, for example). An exception here might be use of variable families (e.g., UnivariateNormalDistributionsFamily) in rule signatures, which can cut down on the number of required rules.
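As a rough illustration of the kind of self-contained file meant here, a minimal deterministic gain node might look like the sketch below. The node name MyGain and the hard-coded gain of 2 are hypothetical, the rule bodies are only schematic, and the @node/@rule signatures follow the pattern used elsewhere in this thread:

```julia
using RxInfer

# Hypothetical deterministic node implementing out = 2 * x, restricted to
# univariate Gaussian messages to avoid most multiple dispatch.
struct MyGain end

@node MyGain Deterministic [out, x]

# Forward rule: push a Gaussian mean/variance through the gain.
@rule MyGain(:out, Marginalisation) (m_x::UnivariateNormalDistributionsFamily,) = begin
    m, v = mean_var(m_x)
    return NormalMeanVariance(2.0 * m, 4.0 * v)
end

# Backward rule: invert the (linear, invertible) gain.
@rule MyGain(:x, Marginalisation) (m_out::UnivariateNormalDistributionsFamily,) = begin
    m, v = mean_var(m_out)
    return NormalMeanVariance(m / 2.0, v / 4.0)
end
```

A file like this, paired with a bare-minimum model and inference call, would let a new user verify each piece in isolation.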

I'm not suggesting that someone would use these nodes in their model, but rather they would serve as examples of how to piece together different building blocks, quick and dirty, to build (increasingly complex) custom nodes. Importantly, they would also serve as examples of how to get the syntax right so that the code runs. I'm suggesting that all the pieces, including calls to approximation methods, reside within the rules of a custom node, not called via a meta specification to code elsewhere. As a user, for example, I might want to try out both linearization and importance sampling in the rule of a node, or implement my own approximation method, and halt the code within the rule to examine function outputs. Efficiency and code elegance are not the goals here.

Specifics

Below I provide a list of specific questions, comments, and observations. I'm not necessarily looking for answers to the questions, but offer them as examples of the kinds of questions I have had and that new users might have. Perhaps you could address them in the documentation.

  1. The use of tiny/huge sometimes makes RxInfer code choke. I've needed to convert tiny/huge to floats (that I call Tiny/Huge) to prevent this.
  2. Some function calls in the examples do not seem to work. For example, Gamma(shape=x,rate=x) in the Gamma mixture example throws errors.
  3. When would someone use the inference() function vs. calls to update!()? Any specific examples? It seems that with the new callbacks, you can do anything with the inference() function that you can with the update!() function.
  4. Why do examples seem to use calls like x_prior ~ Normal(...) followed by x0 = x_prior when setting an initial prior? Why can't this occur in one call, x0 ~ Normal(...), where x0 is then updated in every step of a for loop?
  5. How do you define stochastic vs. deterministic for defining node types? For example, if I use importance sampling for passing messages through a custom + node, is the node stochastic or deterministic? Internally, how are the two cases handled differently?
  6. Is it possible to pass messages down a graph branch and back via only custom rules that call custom nodes that have rules that call custom nodes and so on, or just rules that call rules that call rules?
  7. Is it possible to call an entirely new model from within a custom node rule?
  8. It seems that one can pass data to a model using the call signature of the model (e.g., @model my_model(; mu)) or via use of datavars in the model (and associated use of the data parameter in the inference() function). For example, if a model contains the code x ~ Normal(mu, 2.), one can send mu to the model either way. Under what circumstances must some data be sent via datavars and others be sent via the signature call?
  9. It would be good to have a list of approximation methods in the documentation (with brief descriptions). Currently, the list is found by looking at the approximations folder.
  10. In the custom node examples, it would be good to have one or more nodes that:
    -- Require initialization of marginals, and some nodes that require initialization of marginals and messages.
    -- Have in the model some complex if-then conditional statements. This could be paired with the p(x>z) < eta custom node I suggested above.
    -- Implement a distribution from Distributions.jl (e.g., Poisson)
    -- Have a rule that calls other rules and/or nodes and/or models.
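Regarding item 1 above, a minimal sketch of the workaround (the Tiny/Huge names are just a convention, and it assumes tiny/huge are in scope after `using RxInfer`, as in the code later in this thread):

```julia
using RxInfer

# `tiny` and `huge` are special constants re-exported from ReactiveMP; some
# calls choke on them, so convert them to plain Float64 values up front.
const Tiny = convert(Float64, tiny)
const Huge = convert(Float64, huge)

# Example use: clamp a computed variance into a numerically safe range.
safe_var = clamp(0.0, Tiny, Huge)
```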

Finally, it would be good to include a full example that is described by narrative text. Consider the following, for example.

Jack and Jill (or Jack and Jake, or Jane and Jill) were recently introduced and both wish to meet tonight for dinner at a restaurant, on a planned date. Jill needs to run an errand first and depending on how long that takes she might not be able to meet Jack. She is not in control of the errand timing and will only know when she will finish about 5-30 minutes before she actually does finish. Once finished, it will take her roughly 30 minutes to travel to the restaurant. She will try to call Jack if it looks like she can't make the date. But she is afraid that if she calls too late, he will be irritated and not want to meet her again. She is willing to accept a 5 percent chance of calling too late.

Jack must run two consecutive errands before meeting Jill. For each, he has a reasonable idea of how long they might take. And he has some control over the events. He can make each finish earlier (by tightening the variance of the event or shifting its mean), if it looks like it's taking too long. His travel time after the errands is also roughly 30 minutes. He plans to call Jill if he has to cancel, but does not want to call too late. He too is concerned that if he calls too late, Jill will be irritated and not want to meet him again.

If they meet for dinner, there is a better than average chance that they will enjoy each other's company and want to meet again. But if one person is later than the other, the chance goes down. If both are equally late, there is no penalty. If either is more than 5 minutes late but does not call first, the date will go poorly and there will not be another. If either calls but is not late, that too will reduce the chance of a future date. Each time a person calls, there is another penalty.

The questions are, do they meet that night for dinner? Does anyone call? Do they meet on a later night for dinner? Does Jack have to alter his errands?

The specific probabilities of event occurrences are important for the outcome, but not important for the concept. They can be anything reasonable. Some interesting aspects of the story are that both characters undergo an errand process independent of the other. Those processes involve chance constraints and optimal control, and for Jack they occur twice. And the story line involves several Boolean decision variables. Plus, it stems from a narrative rather than mathematical form. The story could be set up within one model. But could it also be formed as, say, three separate models? Could each errand branch be processed as a sequence of rules within rules? How many ways could this problem be conceived? (Related, how many ways can branch and recursive programming be implemented in RxInfer?)

The narrative aspect is important here. The idea is that RxInfer might be suitable for modeling a person's belief system, as that person is depicted in a story. As such, there is a meta process of: Narrative --> Factor graph --> Simulation --> Results. The example would serve as a warm up.

@albertpod albertpod self-assigned this Apr 3, 2023
@bvdmitri
Member

bvdmitri commented Apr 3, 2023

Hey @John-Boik! Many thanks! Your feedback is invaluable. It is important for us to get such descriptive comments from someone outside of our lab. We indeed think of custom nodes and rules as one of the selling points of RxInfer, and it's a shame that we don't have a descriptive and user-friendly tutorial about them (yet). We are going to discuss your feedback and suggestions at one of our internal meetings and will try to address them as soon as we can.

@albertpod albertpod added documentation Improvements or additions to documentation enhancement New feature or request good first issue Good for newcomers labels Apr 3, 2023
@John-Boik
Author

Related to my previous post, and perhaps a topic deserving of documentation, suppose that I have a custom node for a function that looks like z = f(x, y). In the model statement, it might look something like z ~ MyNode(x, y). I wish to experiment with rules for x (i.e., @rule MyNode(:x, Marginalisation)). How would I obtain q_ins within the rule body, if desired? The q_ins are used in Delta node rules, but how can they be obtained for the rules of custom nodes?

On a related topic, perhaps deserving of additional documentation, in test_in.jl for the delta node (unscented), the function h(x,y) has two inputs. For the "Multiple input with unknown inverse" test, the calling signature is:

DeltaFn{h}((:in, k = 2))

And inputs are defined as:

q_ins = JointNormal(MvNormalMeanCovariance(ones(3), diageye(3)), ((), (), ())), 
m_in = NormalMeanVariance(0.0, 10.0), 
meta = DeltaMeta(; method = Unscented())

Assume this rule is for the backward message x. Then somehow, the above signature must contain messages m_y and m_out, and also the marginals for y and out. But if so, I don't understand how.

What is k?

What do the two inputs to JointNormal refer to? I assume the first refers to q_in for out, x, and y, in that order. Correct? What does the tuple of empty tuples refer to? Do the q_ins refer to marginals or messages?

In the next test, the tuple of tuples is ((1,), (2,), ()). Other tests use ((0,), (0,)) and ((1,), (1,)). Again, what do these refer to?

In short, how does the signature contain messages m_y and m_out and the marginals for y and out?

While these topics might be good to address in documentation, I'm also hoping to learn the answers here.

@bvdmitri
Member

bvdmitri commented Apr 12, 2023

Hey @John-Boik, thank you for your questions! I'll do my best to provide clear answers.

How would I obtain q_ins within the rule body, if desired?

Regarding obtaining q_ins within the rule body, you can compute it manually. Delta nodes have special behavior and do not follow the default arguments specification for performance reasons. You can refer to this code example here. If you wish to access q_ins in your rule, you can incorporate similar code that computes q_ins within your rule.

Then somehow, the above signature must contain messages m_y and m_out, and also the marginals for y and out. But if so, I don't understand how.
What is k?

Our inference backend uses in_1, in_2, etc. instead of x, y, etc., but I will use the x, y notation to explain this example.
You are right in saying that technically we need m_out and m_y to compute m_x, but it is not strictly necessary, because the backward message on x can be computed in a different way, which happens to be more efficient in our case. You can compute the backward message on x by marginalizing q(x, y) and dividing the resulting marginal q(x) by the forward message on x, such that the backward message is equal to

m_backward(x) ∝ q(x) / m_forward(x),  where  q(x) = ∫ q(x, y) dy

This approach may work because of the specific assumptions and characteristics of the unscented method with Gaussian distributions. In delta node rules, we precompute q_ins and then marginalize all arguments except one and divide to compute backward messages. The k = 2 refers to the "index" of the edge for which we are computing the message, e.g., in_2. Here we get to your last question:

What do the two inputs to JointNormal refer to?

JointNormal is a structure that provides a fast marginalisation routine. It represents q(in_1, in_2, in_3, ...), and the tuples ((), (), ()), etc. represent the dimensionality of each component of the joint distribution.

  • () refers to univariate
  • (n, ) refers to multivariate of size n

In general, a user should never interact with the JointNormal structure, and it is never returned as a result of inference. It is simply an internal data structure for handling the joint distribution of normally distributed variables.

In short, how does the signature contain messages m_y and m_out and the marginals for y and out?

It does not. It contains a joint marginal over in_1 and in_2, such that we can marginalize and divide.
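For Gaussians in weighted-mean/precision (natural) parameterization, this marginalize-and-divide step reduces to a subtraction of natural parameters (a sketch; the notation here is not RxInfer's):

```latex
q(in_k) = \int q(in_1, \ldots, in_n)\,\mathrm{d}in_{\setminus k},
\qquad
m_{bw}(in_k) \propto \frac{q(in_k)}{m_{fw}(in_k)}
\;\Longrightarrow\;
\xi_{bw} = \xi_{q} - \xi_{fw}, \quad \Lambda_{bw} = \Lambda_{q} - \Lambda_{fw}
```

where ξ = Λμ is the weighted mean and Λ the precision; the division becomes a subtraction because Gaussian densities multiply by adding natural parameters.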

@albertpod
Member

Hi @John-Boik! Another follow-up on your suggestions.
We have created a few associated issues: #97, #98, #99.

As for specifics:

  1. We will be improving the handling of tiny/huge to prevent related issues, starting with Propagate log of Categorical distribution parameters? ReactiveMP.jl#75.
  2. We will be updating the examples, including the Gamma mixture example, to fix any errors.
  3. We have noted your suggestion regarding the inference() function and will consider it for future improvements.
  4. The refactoring of GraphPPL will streamline the process of setting initial priors.
  5. We will clarify the distinction between stochastic and deterministic node types in our documentation: Revise node-creation section in docs #99.
  6. We will address these issues by introducing nested model specification (GraphPPL refactor) @wouterwln.
  7. ~6
  8. We will improve data passing in the model, addressing the use of datavars and signature calls.
  9. We are currently working on expanding our documentation, including a list of approximation methods with brief descriptions. The approximations will be moved to a separate package.
  10. We will enhance our custom node examples with the suggested features, such as initialization and complex if-then statements.

Thank you once again for your valuable feedback, and we look forward to making these improvements.

@John-Boik
Author

Thanks @bvdmitri for answering my questions. That helped quite a bit. But I think there is still a bit more to the story. I did not understand until I read your reply that the q_ins marginals sent to @rule DeltaFn((:in, k), Marginalisation) were already updated, meaning that they already take into account the backward message out_bw, for the node's out variable. So sure, it makes sense to calculate message x1_bw by dividing the marginal of x1 by the x1_fwd message.

But, including all this in a custom node might not be as simple as you suggest, if I understood you correctly. You had said that the q_ins could be calculated using code similar to that of @marginalrule DeltaFn, but I don't think that will work without something additional.

Consider a custom node for the simple function x3 = x1 + x2. RxInfer usually uses out, but here I use x3 as a label. We can represent the function as: fx_add(x1, x2) = x1 + x2 and the model would include a statement like: x3 ~ CustomNode(x1, x2). Three rules are required for this custom node, one for each edge. The input to each rule is two variables. For example, inputs for the forward rule are m_x1 and m_x2, which I will write as messages x1_fw and x2_fw.

In the case of a delta node, the model statement could look like:

x3 ~ fx_add(x1, x2) where { meta = DeltaMeta(method = Unscented()) }

For the delta node, the order of computation is:

  1. The @marginalrule DeltaFn is called with input messages x3_bw and x1_fw and x2_fw.
  2. In the marginal rule, unscented_statistics is called with inputs fx_add and messages x1_fw, and x2_fw. It returns the x3_fw message and the covariance C between the x1:x3 and x2:x3 forward messages.
  3. RTS is called with inputs C and messages x3_fw, x1_fw and x2_fw, and x3_bw. It returns the marginals for x1 and x2.
  4. The marginals x1 and x2 are sent to @rule DeltaFn((:in, k), Marginalisation), where messages x1_bw or x2_bw are calculated by division, as already noted.

It appears that a simple application of code similar to that of @marginalrule DeltaFn within a custom node rule will not work. First, there are only two inputs to a custom node rule. In particular, for the forward rule, x3_bw is not an input. So items 1 and 3 above can't be mimicked. Second, even if the x3_bw message was available to the forward custom node rule, the marginals calculated in item 3 cannot be sent to the backward rules for x1 and x2.

So, is there a way to calculate q_ins within the rules of custom nodes? I suppose that one way might be to hijack the delta node machinery (including layouts, etc.) somehow creating a near-replica to handle a custom node. But that sounds complicated. A much easier solution would be to have a getmessage function that is callable from the forward rule that would allow one to obtain x3_bw, as in my example. Perhaps such a function already exists, but if so I've not been able to find it. The function getmarginal exists, but I don't know how to use it within a rule to obtain a marginal. The syntax getmarginal(m_x1) clearly does not work. Similarly, one would need a setmarginal! function, so that one could calculate a marginal in the forward rule, save it, and then recall it for later use in a backward rule. Again, this function exists, but I don't know if it can do what I need within a rule. And one would also need to take into account variable indexes for getmarginal, setmarginal!, and getmessage, in case you needed to get or set the marginal or message for, say, x3[k].

Another option would be to send x3_bw in a pipeline from the call to CustomNode in the model. Is something like this possible? But once marginals are calculated in the forward rule, one would still need to call setmarginal! in that rule and getmarginal in the rules for x1 and x2.

Yet another option would be to include in the custom node structure items like out::NodeInterface and ins::NTuple{N, IndexedNodeInterface}, similar to what is done for delta nodes. Perhaps there is a simple way to do this that would achieve my aims. If so, can you offer a code snippet?

Or, maybe I have missed something and the answer is right in front of me. Any suggestions would be appreciated. Perhaps you are wondering why I would want to calculate q_ins within a custom node. After all, RxInfer already has efficient and elegant code (i.e., for delta nodes) that does what I want to duplicate. It is so that I can experiment with different approaches, and learn about RxInfer in the process. I would guess that other users might also want to experiment in this way as they try out new (or even old) ideas. And I would like to have full control of, say, some complicated composite node.

@bvdmitri
Member

It appears that a simple application of code similar to that of @marginalrule DeltaFn within a custom node rule will not work. First, there are only two inputs to a custom node rule. In particular, for the forward rule, x3_bw is not an input.

You're right, my bad. I missed that. But your next question helps:

A much easier solution would be to have a getmessage function that is callable from the forward rule that would allow one to obtain x3_bw, as in my example.
There is already a solution to that, which is briefly mentioned in one of the examples, but is not properly documented (we will work on that).

I assume you have the following node specification (or something similar)

function fx_add end

@node typeof(fx_add) Deterministic [ out, x1, x2 ]

So if you simply want to get the message on the same edge from within the forward rule, you write it like:

x3 ~ fx_add(x1, x2) where { pipeline = RequireMessage(out = NormalMeanPrecision(0, 0.01))  }

This specification explicitly states that you require the inbound message on the out interface.

The same applies to other interfaces, so if you want messages on the same edges for all interfaces you can write

x3 ~ fx_add(x1, x2) where { 
    pipeline = RequireMessage(
        out = NormalMeanPrecision(0.0, 0.01),
        x1 = NormalMeanPrecision(0.0, 0.01),
        x2 = NormalMeanPrecision(0.0, 0.01),
    )  
}

The = is optional and indicates that you want to preinitialize the message with something, but you can also write it like:

x3 ~ fx_add(x1, x2) where { 
    pipeline = RequireMessage(out, x1, x2)  
}

which says that you want inbound messages on all edges when computing outbound messages, but without any explicit initialisation (will probably require initmessages = in the inference function).

Second, even if the x3_bw message was available to the forward custom node rule, the marginals calculated in item 3 cannot be sent to the backward rules for x1 and x2.

You are right and that is the reason why delta nodes have a custom layout for rules, but if you only want to experiment I would advise you to recompute the marginals in all rules, which is easier for now.

P.S. Note that there is currently a small bug in the model specification library with RequireMessage, but it does not affect much.

P.P.S. This is the code I used to quickly test the pipeline specification

using RxInfer

function fx_add end

@node typeof(fx_add) Deterministic [ out, x1, x2 ]

@model function try_fx_add()
   y = datavar(Float64)
   # Avoid using interfaces names for the variables due
   # to the bug in GraphPPL, I put `_`
   x1_ ~ NormalMeanVariance(0.0, 1.0)
   x2_ ~ NormalMeanVariance(0.0, 1.0)
   out_ ~ fx_add(x1_, x2_) where { pipeline = RequireMessage(out, x1, x2) }
   y ~ NormalMeanVariance(out_, 1.0)
end

inference(
    model = try_fx_add(),
    data = (y = 1.0, )
)

and I get the expected error with a missing rule

RuleMethodError: no method matching rule for the given arguments

Possible fix, define:

@rule typeof(fx_add)(:x2, Marginalisation) (m_out::NormalMeanVariance, m_x1::NormalMeanVariance, m_x2::NormalMeanVariance, ) = begin 
    return ...
end

Note that the x2 interface requires message m_x2.

@John-Boik
Author

Thanks @bvdmitri, that was very helpful. But I think a tweak is still needed. In the rules for calculating the backward messages, say, x1, the updated marginal for x1 is needed (updated based on information from out_fw and out_bw). That marginal cannot be calculated in the x1 rule. You could use x1_fw and x2_fw to calculate out_fw, but you don't have access there to out_bw.

I include code below for a complete working example. In it, I calculate the marginals for x1 and x2 in the rule for out, then save those marginals in a structure for later recall. It's not elegant, but it works as long as the x1 and x2 rules are called after the out rule. If you have a better idea, I would be happy to hear it. Of course, I could also initialize the q_ins of the structure with a real value, so that rules for x1 and x2 could be called before the rule for out. As an aside, I use HCubature in the code just for fun. I could have also used the cubature function supplied by RxInfer.

I do have a related question. Suppose that you wanted to use the delta method and have an inverse function for only one of the inputs, say, x1 in my example. How would you specify that you have only one of the inverses? I'm guessing something like DeltaMeta(method = Linearization(), inverse = (x1_inv, nothing)). Would that be correct? Here I am trying to specify that the inverse for x2 is not available.

And I have a comment. The search function for RxInfer documentation seems to have some issues. For example, search for the term "inv" and you will get a list of links to pages, but for some of those pages the term "inv" only occurs in the left navigation panel, not in the actual page on the right. I've noticed this for a variety of search terms and it makes it more difficult to find what you are looking for. Perhaps this is an issue that is beyond your control to fix. But if you can fix it, that would be helpful.

module RxInferAddExample


import Distributions
import Random
import Statistics
import HCubature
import ReactiveMP

using Formatting
using Infiltrator
using Revise
using RxInfer

Random.seed!(51233) # Set random seed for reproducibility
Huge = convert(Float64, huge)
Tiny = convert(Float64, tiny)


####################################################################################################
struct CustomAddNode end

mutable struct Q
    q_ins::Union{Nothing, JointNormal} 
end
Qins = Q(nothing)

@node CustomAddNode Deterministic [out, x1, x2]

pdf(d, u) = Distributions.pdf(d, u)


# --- add function ---------------------------------------------------------------------------------
function fx_add(x1, x2)
    return x1 + x2
end


# -- marginalrule for delta function from rules/delta/unscented/marginals.jl -----------------------
function apply_marginal_rule(m_x1, m_x2, m_out, fx)
    # Approximate joint inbounds
    m_ins = [m_x1, m_x2]
    statistics = mean_cov.(m_ins)
    μs_fw_in = first.(statistics)
    Σs_fw_in = last.(statistics)
    sizes = size.(m_ins)
    
    
    μs_fw_in = tuple(μs_fw_in...)
    Σs_fw_in = tuple(Σs_fw_in...)
    (μ_tilde, Σ_tilde, C_tilde) = ReactiveMP.unscented_statistics(Unscented(), Val(true), fx_add, μs_fw_in, Σs_fw_in)

    joint              = convert(JointNormal, μs_fw_in, Σs_fw_in)
    (μ_fw_in, Σ_fw_in) = mean_cov(joint)
    ds                 = ReactiveMP.dimensionalities(joint)

    # Apply the RTS smoother
    (μ_bw_out, Σ_bw_out) = mean_cov(m_out)
    (μ_in, Σ_in) = ReactiveMP.smoothRTS(μ_tilde, Σ_tilde, C_tilde, μ_fw_in, Σ_fw_in, μ_bw_out, Σ_bw_out)

    dist = convert(promote_variate_type(variate_form(μ_in), NormalMeanVariance), μ_in, Σ_in)
    
    #@infiltrate; @assert false   
    return dist
end


# --- out ------------------------------------------------------------------------------------------
@rule CustomAddNode(:out, Marginalisation) (
    m_out::UnivariateNormalDistributionsFamily,
    m_x1::UnivariateNormalDistributionsFamily, 
    m_x2::UnivariateNormalDistributionsFamily, 
    ) = begin 
    
    # use cubature for mean of forward message
    function f_mean(z) 
        x1 = z[1]  
        x2 = z[2]
        return fx_add(x1, x2) * pdf(m_x1, x1) * pdf(m_x2, x2) 
    end
    
    lo = -20
    hi = 20
    (mean_out, mean_out_err) = HCubature.hcubature(f_mean, (lo,lo), (hi,hi), initdiv=2, atol=1e-9)  
    
    # use cubature for variance of forward message
    function f_var(z) 
        x1 = z[1]  
        x2 = z[2]
        p = fx_add(x1, x2) * pdf(m_x1, x1) * pdf(m_x2, x2)
        p = p .* fx_add(x1, x2)
        return  p
    end
    
    (var_out, var_out_err) = HCubature.hcubature(f_var, (lo,lo), (hi,hi), initdiv=2, atol=1e-9)
    var_out = var_out - mean_out^2  # f_var integrates E[f^2]; variance = E[f^2] - E[f]^2
    
    dist = apply_marginal_rule(m_x1, m_x2, m_out, fx_add)
    sizes = ((), ())
    Qins.q_ins = JointNormal(dist, sizes)
    
    #@infiltrate; @assert false    
    return NormalMeanVariance(mean_out, clamp(var_out, Tiny, Huge)) 
end


# --- x1 -------------------------------------------------------------------------------------------
@rule CustomAddNode(:x1, Marginalisation) (
    m_out::UnivariateNormalDistributionsFamily, 
    m_x1::UnivariateNormalDistributionsFamily, 
    m_x2::UnivariateNormalDistributionsFamily, 
    ) = begin 
    
    q_ins = Qins.q_ins
    k=1
    
    # Divide marginal on inx by forward message
    ξ_inx, Λ_inx       = weightedmean_precision(getmarginal(q_ins, k))
    ξ_fw_inx, Λ_fw_inx = weightedmean_precision(m_x1)

    ξ_bw_inx = ξ_inx - ξ_fw_inx
    Λ_bw_inx = Λ_inx - Λ_fw_inx # Note: subtraction might lead to posdef violations

    #@infiltrate; @assert false
    return convert(promote_variate_type(variate_form(ξ_inx), NormalWeightedMeanPrecision), ξ_bw_inx, Λ_bw_inx)
end 
    

# --- x2 -------------------------------------------------------------------------------------------
@rule CustomAddNode(:x2, Marginalisation) (
    m_out::UnivariateNormalDistributionsFamily, 
    m_x1::UnivariateNormalDistributionsFamily, 
    m_x2::UnivariateNormalDistributionsFamily, 
    ) = begin 
    
    q_ins = Qins.q_ins
    k=2
    
    # Divide marginal on inx by forward message
    ξ_inx, Λ_inx       = weightedmean_precision(getmarginal(q_ins, k))
    ξ_fw_inx, Λ_fw_inx = weightedmean_precision(m_x2)

    ξ_bw_inx = ξ_inx - ξ_fw_inx
    Λ_bw_inx = Λ_inx - Λ_fw_inx # Note: subtraction might lead to posdef violations

    #@infiltrate; @assert false
    return convert(promote_variate_type(variate_form(ξ_inx), NormalWeightedMeanPrecision), ξ_bw_inx, Λ_bw_inx)
end 


# --- marginal rule --------------------------------------------------------------------------------
@marginalrule CustomAddNode(:x1_x2) (
    m_out::UnivariateNormalDistributionsFamily, 
    m_x1::UnivariateNormalDistributionsFamily, 
    m_x2::UnivariateNormalDistributionsFamily, ) = begin 
    
    xi_3, W_3 = weightedmean_precision(m_out)
    xi_1, W_1 = weightedmean_precision(m_x1)
    xi_2, W_2 = weightedmean_precision(m_x2)
    
    #@infiltrate; @assert false
    return MvNormalWeightedMeanPrecision([xi_1 + xi_3; xi_2 + xi_3], [W_1 + W_3 W_3; W_3 W_2 + W_3])
end


# --------------------------------------------------------------------------------------------------
@model function rx_model(;N)
    
    # datavars
    x1_mean = datavar(Float64, N)
    x1_var = datavar(Float64, N)
    x2_mean = datavar(Float64, N)
    x2_var = datavar(Float64, N)
    out_mean = datavar(Float64, N)
    out_var = datavar(Float64, N)
    
    # random vars
    x1_ = randomvar(N)
    x2_ = randomvar(N)
    out_ = randomvar(N)
    
    # loop
    x0 = 0.
    for k in 1:1
        x1_[k] ~ NormalMeanVariance(x1_mean[k], x1_var[k])
        x2_[k] ~ NormalMeanVariance(x2_mean[k], x2_var[k])
                
        # custom node
        out_[k] ~ x0 + CustomAddNode(x1_[k], x2_[k])  where { pipeline = RequireMessage(out, x1, x2) }
        
        # delta method
        #out_[k] ~ x0 + fx_add(x1_[k], x2_[k]) where { meta = DeltaMeta(method = Unscented()) }
        
        # regular addition
        #out_[k] ~ x0 + x1_[k] + x2_[k]
        
        out_[k] ~ NormalMeanVariance(out_mean[k], out_var[k]) 
        x0 = out_[k]  
    end    
    return (x1_,x2_,out_)
end


# --------------------------------------------------------------------------------------------------
function main()
    result = inference(
        data = (
            x1_mean = [2.],
            x1_var = [4.],
            x2_mean = [4.],
            x2_var = [2.],
            out_mean = [0.],
            out_var = [Huge],
        ),
        model = rx_model(N=1,),
        
        initmessages = (
            x1_ = NormalMeanVariance(0.0, 10.),
            x2_ = NormalMeanVariance(0.0, 10.),
            out_ = NormalMeanVariance(0.0, 10.),
        ),
        
        returnvars = (
            x1_ = KeepLast(),
            x2_ = KeepLast(),
            out_ = KeepLast(),
        ),
        free_energy = true,
        iterations = 2,
    )
    
    out = mean_var.(result.posteriors[:out_])
        
    printfmtln("\n=========\nout out = {}\n", out)  # expect Normal(6., 6.)
    printfmtln("free energy= {}\n", result.free_energy)
    #@infiltrate; @assert false
end

end  #  module -----------------------------

RxInferAddExample.main()
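A closed-form sanity check on the `# expect Normal(6., 6.)` comment in `main()` (plain Julia, independent of RxInfer): with an uninformative message on `out`, the forward message is just the sum of two independent Gaussians, and the cubature moments must satisfy Var = E[f^2] - E[f]^2.

```julia
# x1 ~ N(2, 4), x2 ~ N(4, 2) in mean/variance parameterisation, out = x1 + x2
m1, v1 = 2.0, 4.0
m2, v2 = 4.0, 2.0
mean_out = m1 + m2                      # E[x1 + x2] = 6.0
second_moment = (v1 + v2) + mean_out^2  # E[(x1 + x2)^2] = 42.0
var_out = second_moment - mean_out^2    # 6.0, so out ~ N(6, 6) as expected
```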

@John-Boik
Author

My bad, @bvdmitri, I misunderstood the meaning of a pipeline with multiple messages, and my tweak was not necessary (although it does save a little computation). A working version using a pipeline with multiple messages is below.

For your convenience, I repeat my question here. Suppose that you wanted to use a delta node and have an inverse function for only one of the inputs, say, x1 in my example. How would you specify that you have only one of the inverses? I'm guessing something like DeltaMeta(method = Linearization(), inverse = (x1_inv, nothing)). Would that be correct?

Thanks again for your help.
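For reference, the variant that is already documented supplies an inverse for every input interface; a sketch (the `inverse` keyword and `Linearization` are from RxInfer's delta-node manual, while `x1_inv`/`x2_inv` are hypothetical user-defined helpers):

```julia
# Hypothetical analytical inverses for fx_add(x1, x2) = x1 + x2:
# each takes the output plus the remaining inputs and recovers one input.
x1_inv(out, x2) = out - x2
x2_inv(out, x1) = out - x1

# Inside an @model (sketch, assuming RxInfer's DeltaMeta API):
#   out_[k] ~ fx_add(x1_[k], x2_[k]) where {
#       meta = DeltaMeta(method = Linearization(), inverse = (x1_inv, x2_inv))
#   }
```

The open question is whether a `nothing` placeholder can stand in for a single missing inverse in that tuple.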

module RxInferAddExample

import Distributions
import Random
import Statistics
import HCubature
import ReactiveMP

using Formatting
using Infiltrator
using Revise
using RxInfer

Random.seed!(51233) # Set random seed for reproducibility
Huge = convert(Float64, huge)
Tiny = convert(Float64, tiny)


####################################################################################################
struct CustomAddNode end

@node CustomAddNode Deterministic [out, x1, x2]

pdf(d, u) = Distributions.pdf(d, u)


# --- add function ---------------------------------------------------------------------------------
function fx_add(x1, x2)
    return x1 + x2
end


# -- marginalrule for delta function from rules/delta/unscented/marginals.jl -----------------------
function apply_marginal_rule(m_x1, m_x2, m_out, fx)
    # Approximate joint inbounds
    
    m_ins = [m_x1, m_x2]
    statistics = mean_cov.(m_ins)
    μs_fw_in = first.(statistics)
    Σs_fw_in = last.(statistics)
    sizes = size.(m_ins)
    
    
    μs_fw_in = tuple(μs_fw_in...)
    Σs_fw_in = tuple(Σs_fw_in...)
    (μ_tilde, Σ_tilde, C_tilde) = ReactiveMP.unscented_statistics(Unscented(), Val(true), fx_add, μs_fw_in, Σs_fw_in)

    joint              = convert(JointNormal, μs_fw_in, Σs_fw_in)
    (μ_fw_in, Σ_fw_in) = mean_cov(joint)
    ds                 = ReactiveMP.dimensionalities(joint)

    # Apply the RTS smoother
    (μ_bw_out, Σ_bw_out) = mean_cov(m_out)
    (μ_in, Σ_in) = ReactiveMP.smoothRTS(μ_tilde, Σ_tilde, C_tilde, μ_fw_in, Σ_fw_in, μ_bw_out, Σ_bw_out)

    dist = convert(promote_variate_type(variate_form(μ_in), NormalMeanVariance), μ_in, Σ_in)
    
    #@infiltrate; @assert false   
    return dist
end


# --- out ------------------------------------------------------------------------------------------
@rule CustomAddNode(:out, Marginalisation) (
    m_out::UnivariateNormalDistributionsFamily,
    m_x1::UnivariateNormalDistributionsFamily, 
    m_x2::UnivariateNormalDistributionsFamily, 
    ) = begin 
    
    # use cubature for mean of forward message
    function f_mean(z) 
        x1 = z[1]  
        x2 = z[2]
        return fx_add(x1, x2) * pdf(m_x1, x1) * pdf(m_x2, x2) 
    end
    
    lo = -20
    hi = 20
    (mean_out, mean_out_err) = HCubature.hcubature(f_mean, (lo,lo), (hi,hi), initdiv=2, atol=1e-9)  
    
    # use cubature for variance of forward message
    function f_var(z) 
        x1 = z[1]  
        x2 = z[2]
        p = fx_add(x1, x2) * pdf(m_x1, x1) * pdf(m_x2, x2)
        p = p .* fx_add(x1, x2)
        return  p
    end
    
    (Ef2_out, Ef2_out_err) = HCubature.hcubature(f_var, (lo,lo), (hi,hi), initdiv=2, atol=1e-9)
    var_out = Ef2_out - mean_out^2  # Var = E[f^2] - E[f]^2; note f_var here, not f_mean
        
    #@infiltrate; @assert false    
    return NormalMeanVariance(mean_out, clamp(var_out, Tiny, Huge)) 
end


# --- x1 -------------------------------------------------------------------------------------------
@rule CustomAddNode(:x1, Marginalisation) (
    m_out::UnivariateNormalDistributionsFamily, 
    m_x1::UnivariateNormalDistributionsFamily, 
    m_x2::UnivariateNormalDistributionsFamily, 
    ) = begin 
    
    out_fwd = @call_rule CustomAddNode(:out, Marginalisation) (m_out=m_out, m_x1=m_x1, m_x2=m_x2)
    dist = apply_marginal_rule(m_x1, m_x2, m_out, fx_add)
    sizes = ((), ())
    q_ins = JointNormal(dist, sizes)
        
    # Divide marginal on inx by forward message
    k=1
    ξ_inx, Λ_inx       = weightedmean_precision(getmarginal(q_ins, k))
    ξ_fw_inx, Λ_fw_inx = weightedmean_precision(m_x1)

    ξ_bw_inx = ξ_inx - ξ_fw_inx
    Λ_bw_inx = Λ_inx - Λ_fw_inx # Note: subtraction might lead to posdef violations

    #@infiltrate; @assert false
    return convert(promote_variate_type(variate_form(ξ_inx), NormalWeightedMeanPrecision), ξ_bw_inx, Λ_bw_inx)
end 
    

# --- x2 -------------------------------------------------------------------------------------------
@rule CustomAddNode(:x2, Marginalisation) (
    m_out::UnivariateNormalDistributionsFamily, 
    m_x1::UnivariateNormalDistributionsFamily, 
    m_x2::UnivariateNormalDistributionsFamily, 
    ) = begin 
    
    out_fwd = @call_rule CustomAddNode(:out, Marginalisation) (m_out=m_out, m_x1=m_x1, m_x2=m_x2)
    dist = apply_marginal_rule(m_x1, m_x2, m_out, fx_add)
    sizes = ((), ())
    q_ins = JointNormal(dist, sizes)
        
    # Divide marginal on inx by forward message
    k=2
    ξ_inx, Λ_inx       = weightedmean_precision(getmarginal(q_ins, k))
    ξ_fw_inx, Λ_fw_inx = weightedmean_precision(m_x2)

    ξ_bw_inx = ξ_inx - ξ_fw_inx
    Λ_bw_inx = Λ_inx - Λ_fw_inx # Note: subtraction might lead to posdef violations

    #@infiltrate; @assert false
    return convert(promote_variate_type(variate_form(ξ_inx), NormalWeightedMeanPrecision), ξ_bw_inx, Λ_bw_inx)
end 


# --- marginal rule --------------------------------------------------------------------------------
@marginalrule CustomAddNode(:x1_x2) (
    m_out::UnivariateNormalDistributionsFamily, 
    m_x1::UnivariateNormalDistributionsFamily, 
    m_x2::UnivariateNormalDistributionsFamily, ) = begin 
    
    xi_3, W_3 = weightedmean_precision(m_out)
    xi_1, W_1 = weightedmean_precision(m_x1)
    xi_2, W_2 = weightedmean_precision(m_x2)
    
    #@infiltrate; @assert false
    return MvNormalWeightedMeanPrecision([xi_1 + xi_3; xi_2 + xi_3], [W_1 + W_3 W_3; W_3 W_2 + W_3])
end
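
# Derivation note for the marginal rule above (a sketch, not RxInfer internals):
# with the constraint out = x1 + x2 the unnormalised joint is
#   q(x1, x2) ∝ m_x1(x1) * m_x2(x2) * m_out(x1 + x2),
# whose log is quadratic in (x1, x2) with weighted mean [xi_1 + xi_3; xi_2 + xi_3]
# and precision [W_1 + W_3  W_3; W_3  W_2 + W_3], matching the value returned above.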


# --------------------------------------------------------------------------------------------------
@model function rx_model(;N)
    
    # datavars
    x1_mean = datavar(Float64, N)
    x1_var = datavar(Float64, N)
    x2_mean = datavar(Float64, N)
    x2_var = datavar(Float64, N)
    out_mean = datavar(Float64, N)
    out_var = datavar(Float64, N)
    
    # random vars
    x1_ = randomvar(N)
    x2_ = randomvar(N)
    out_ = randomvar(N)
    
    # loop
    x0 = 0.
    for k in 1:1
        x1_[k] ~ NormalMeanVariance(x1_mean[k], x1_var[k])
        x2_[k] ~ NormalMeanVariance(x2_mean[k], x2_var[k])
                
        # custom node
        out_[k] ~ x0 + CustomAddNode(x1_[k], x2_[k])  where { 
            pipeline = RequireMessage(out, x1, x2) 
        }
        
        # delta method
        #out_[k] ~ x0 + fx_add(x1_[k], x2_[k]) where { meta = DeltaMeta(method = Unscented()) }
        
        # regular addition
        #out_[k] ~ x0 + x1_[k] + x2_[k]
        
        out_[k] ~ NormalMeanVariance(out_mean[k], out_var[k]) 
        x0 = out_[k]  
    end    
    return (x1_,x2_,out_)
end


# --------------------------------------------------------------------------------------------------
function main()
    result = inference(
        data = (
            x1_mean = [2.],
            x1_var = [4.],
            x2_mean = [4.],
            x2_var = [2.],
            out_mean = [0.],
            out_var = [Huge],
        ),
        model = rx_model(N=1,),
        
        initmessages = (
            x1_ = NormalMeanVariance(0.0, 10.),
            x2_ = NormalMeanVariance(0.0, 10.),
            out_ = NormalMeanVariance(0.0, 10.),
        ),
        
        returnvars = (
            x1_ = KeepLast(),
            x2_ = KeepLast(),
            out_ = KeepLast(),
        ),
        free_energy = true,
        iterations = 2,
    )
    
    out = mean_var.(result.posteriors[:out_])
        
    printfmtln("\n=========\nout out = {}\n", out)  # expect Normal(6., 6.)
    printfmtln("free energy= {}\n", result.free_energy)
    #@infiltrate; @assert false
end

end  #  module -----------------------------

RxInferAddExample.main()
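The `:x1`/`:x2` rules above divide the joint marginal by the forward message by subtracting natural parameters; a minimal standalone sketch (plain Julia; `to_natural` is a hypothetical helper, not RxInfer's `weightedmean_precision`):

```julia
# Division of a univariate Gaussian marginal by the forward message,
# done by subtracting natural parameters (xi = m/v, Lambda = 1/v).
to_natural(m, v) = (m / v, 1 / v)       # hypothetical helper
xi_marg, L_marg = to_natural(4.0, 1.0)  # example marginal on x1
xi_fw,   L_fw   = to_natural(2.0, 4.0)  # forward message m_x1
xi_bw = xi_marg - xi_fw                 # 4.0 - 0.5  = 3.5
L_bw  = L_marg  - L_fw                  # 1.0 - 0.25 = 0.75; can go non-posdef
# backward message: N(xi_bw / L_bw, 1 / L_bw)
```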

@bvdmitri
Member

How would you specify that you have only one of the inverses? I'm guessing something like DeltaMeta(method = Linearization(), inverse = (x1_inv, nothing)). Would that be correct?

Hey @John-Boik, sorry for the late reply, I was on vacation. Specifying only one of the inverses is not implemented (yet), though the way you guessed it is exactly how we have planned to implement this functionality.

@bvdmitri
Member

Initial draft of the custom nodes and rules is available in the documentation: https://biaslab.github.io/RxInfer.jl/stable/manuals/custom-node/

@bvdmitri
Member

bvdmitri commented Oct 5, 2023

@mhidalgoaraya

Projects: 🤔 Ideas
Development: No branches or pull requests
4 participants