
How do I extract the ELBO? #101

Open
Ananyapam7 opened this issue Aug 20, 2023 · 7 comments

@Ananyapam7

Ananyapam7 commented Aug 20, 2023

Hi, I am using this software for my Master's thesis and I want to extract the ELBO (evidence lower bound) and ideally also track it as training happens. How can I do that? For instance, I have the snippet:

library(PLNmodels)

data(barents)
barents$Offset <- NULL
barents <- prepare_data(barents$Abundance, barents$Covariate, offset = "none")
barents_PLN <- PLN(Abundance ~ 1, data = barents)
parameters <- barents_PLN$model_par
sigma <- parameters$Sigma
mu <- parameters$Theta

I wish to visualize the ELBO during training, i.e. while the call to PLN is running. How can I achieve this? I have looked at the codebase and understood that torch_elbo is the function used to compute the ELBO, so I naively tried to insert a cat right after the loss is computed in PLNfit-class.R. For context, my modification looks like this; I was expecting it to print the values of the ELBO, but it did not:

permute <- torch::torch_randperm(self$n) + 1L
for (batch_idx in 1:num_batch) {
  # here index is a vector of the indices in the batch
  index <- permute[(batch_size*(batch_idx - 1) + 1):(batch_idx*batch_size)]

  ## Optimization
  optimizer$zero_grad()             # reinitialize gradients
  loss <- private$torch_elbo(data, params, index) # compute current ELBO
  cat("loss:", loss)
  loss$backward()                   # backward propagation
  optimizer$step()                  # optimization
}

Any help would be really appreciated.

@mahendra-mariadassou
Collaborator

Hi,

Thank you for your interest in PLNmodels. There are actually two backends used for the optimization part: nlopt (default) and torch. With the code you used, you relied on nlopt which calls C++ code. There is no easy way to extract the trajectory of the ELBO without digging into the source files and recompiling the package.

You can instead rely on the torch backend (as you did), but you need to tell PLN to use it and then use the provided extractors. Unfortunately, initialization is a bit tricky for torch and the objective may diverge depending on the starting point, as is the case here.

library(PLNmodels)
#> This is packages 'PLNmodels' version 1.0.3-9020
#> Use future::plan(multicore/multisession) to speed up PLNPCA/PLNmixture/stability_selection.
data(barents)

## Try running PLN with torch backend: fails
barents_torch <- PLN(Abundance ~ 1, data = barents, control = PLN_param(backend = "torch"))
#> 
#>  Initialization...
#>  Adjusting a full covariance PLN model with torch optimizer
#> Error in if (delta_f < config$ftol_rel) status <- 3: missing value where TRUE/FALSE needed

To circumvent the problem, you can provide a "good" starting point (like the one found by the nlopt backend).

## Initialize PLN with nlopt
barents_nlopt <- PLN(Abundance ~ 1, data = barents)
#> 
#>  Initialization...
#>  Adjusting a full covariance PLN model with nlopt optimizer
#>  Post-treatments...
#>  DONE!
## Optimize PLN with torch starting with the nlopt solution. 
barents_torch <- PLN(Abundance ~ 1, data = barents, control = PLN_param(backend = "torch", inception = barents_nlopt))
#> 
#>  Initialization...
#>  Adjusting a full covariance PLN model with torch optimizer
#>  Post-treatments...
#>  DONE!
## Extract ELBO values
barents_torch$optim_par$objective
#>  [1] -294112.1 -293098.6 -292930.9 -293720.3 -293989.9 -294007.6 -294031.3
#>  [8] -294057.8 -294090.8 -294080.7 -294088.0 -294107.1 -294104.7 -294109.1
#> [15] -294113.4 -294111.4 -294112.7 -294114.8 -294114.9 -294115.2 -294115.7
#> [22] -294115.4 -294115.5 -294115.9 -294116.0 -294116.1 -294116.2 -294116.3
#> [29] -294116.3 -294116.4 -294116.5 -294116.6 -294116.6 -294116.7 -294116.8
#> [36] -294116.8 -294116.9 -294116.9 -294117.0 -294117.0 -294117.1 -294117.2
#> [43] -294117.2 -294117.2 -294117.2 -294117.2

Created on 2023-08-20 with reprex v2.0.2

In that case, the ELBO doesn't change much since the starting point is already close to a stationary point of the ELBO, but the same approach should hopefully work better on your own examples.
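
If you want to visualize the trajectory (as in your original question), here is a minimal sketch of my own, reusing the barents_torch fit from above:

## Plot the objective values recorded along the torch iterations
elbo_trace <- barents_torch$optim_par$objective
plot(seq_along(elbo_trace), elbo_trace, type = "l",
     xlab = "iteration", ylab = "objective",
     main = "ELBO trajectory (torch backend)")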

@mahendra-mariadassou
Collaborator

mahendra-mariadassou commented Aug 21, 2023

You can in fact avoid the inception altogether by choosing a smaller learning rate for the torch backend, which solves the convergence problem and leads to a simpler and more elegant solution.

library(PLNmodels)
#> This is packages 'PLNmodels' version 1.0.3-9020
#> Use future::plan(multicore/multisession) to speed up PLNPCA/PLNmixture/stability_selection.
data(barents)

## Run PLN with torch backend
barents_torch <- PLN(Abundance ~ 1, data = barents, 
                     control = PLN_param(backend = "torch", config_optim = list(lr = 0.01)))
#> 
#>  Initialization...
#>  Adjusting a full covariance PLN model with torch optimizer
#>  Post-treatments...
#>  DONE!
## Extract ELBO values along iterations
barents_torch$optim_par$objective
#>   [1] -289968.9 -290266.2 -290572.5 -290889.0 -291201.1 -291503.0 -291795.5
#>   [8] -292077.9 -292341.1 -292576.4 -292781.8 -292956.7 -293102.4 -293224.4
#>  [15] -293328.0 -293419.7 -293502.2 -293579.1 -293648.1 -293707.4 -293756.3
#>  [22] -293793.8 -293823.8 -293847.6 -293869.1 -293883.7 -293901.3 -293915.1
#>  [29] -293925.2 -293936.9 -293944.9 -293952.4 -293959.2 -293965.0 -293970.3
#>  [36] -293974.9 -293979.2 -293983.2 -293986.9 -293990.4 -293993.5 -293996.5
#>  [43] -293999.1 -294001.6 -294004.0 -294006.2 -294008.4 -294010.3 -294012.2
#>  [50] -294013.9 -294015.4 -294016.9 -294018.2 -294019.5 -294020.6 -294021.7
#>  [57] -294022.8 -294023.8 -294024.7 -294025.6 -294025.9 -294027.1 -294028.0
#>  [64] -294028.7 -294029.6 -294030.3 -294031.1 -294031.8 -294032.5 -294033.2
#>  [71] -294033.8 -294034.1 -294034.9 -294035.4 -294035.9 -294036.5 -294036.9
#>  [78] -294037.3 -294037.8 -294038.2 -294038.5 -294038.9 -294039.3 -294039.6
#>  [85] -294039.9 -294040.2 -294040.5 -294040.8 -294041.1 -294041.4 -294041.6
#>  [92] -294041.9 -294042.2 -294042.4 -294042.7 -294042.9 -294043.2 -294043.4
#>  [99] -294043.7 -294043.9 -294044.2 -294044.4 -294044.7 -294045.0 -294045.2
#> [106] -294045.5 -294045.8 -294046.0 -294046.2 -294046.3 -294046.6 -294046.8
#> [113] -294047.0 -294047.2 -294047.4 -294047.6 -294047.8 -294047.9 -294048.1
#> [120] -294048.3 -294048.4 -294048.6 -294048.8 -294048.9 -294049.1 -294049.2
#> [127] -294049.4 -294049.6 -294049.7 -294049.9 -294050.1 -294050.2 -294050.4
#> [134] -294050.5 -294050.7 -294050.8 -294051.0 -294051.2 -294051.4 -294051.5
#> [141] -294051.6 -294051.8 -294052.0 -294052.1 -294052.3 -294052.4 -294052.6
#> [148] -294052.7 -294052.9 -294053.1 -294053.2 -294053.3 -294053.5 -294053.7
#> [155] -294053.8 -294053.9 -294054.0 -294054.2 -294054.3 -294054.4 -294054.5
#> [162] -294054.7 -294054.8 -294054.9 -294055.0 -294055.2 -294055.3 -294055.4
#> [169] -294055.6 -294055.7 -294055.8 -294056.0 -294056.1 -294056.3 -294056.4
#> [176] -294056.5 -294056.7 -294056.8 -294056.9 -294057.0 -294057.1 -294057.2
#> [183] -294057.4 -294057.5 -294057.6 -294057.8 -294057.9 -294058.0 -294058.2
#> [190] -294058.2 -294058.4 -294058.5 -294058.6 -294058.7 -294058.9 -294058.9
#> [197] -294059.1 -294059.2 -294059.3 -294059.4 -294059.6 -294059.7 -294059.8
#> [204] -294059.9 -294060.0 -294060.1 -294060.2 -294060.3 -294060.4 -294060.6
#> [211] -294060.7 -294060.8 -294060.9 -294061.0 -294061.2 -294061.2 -294061.3
#> [218] -294061.5 -294061.6 -294061.7 -294061.8 -294061.9 -294062.0 -294062.1
#> [225] -294062.2 -294062.3 -294062.4 -294062.5 -294062.7 -294062.8 -294062.9
#> [232] -294063.0 -294063.1 -294063.2 -294063.3 -294063.4 -294063.5 -294063.6
#> [239] -294063.7 -294063.8 -294063.9 -294064.0 -294064.1 -294064.2 -294064.3
#> [246] -294064.4 -294064.5 -294064.6 -294064.7 -294064.8 -294064.9 -294065.0
#> [253] -294065.1 -294065.2 -294065.3 -294065.4 -294065.5 -294065.6 -294065.7
#> [260] -294065.8 -294065.8 -294065.9 -294066.0 -294066.1 -294066.2 -294066.3
#> [267] -294066.4 -294066.5 -294066.6 -294066.7 -294066.8 -294066.8 -294067.0
#> [274] -294067.0 -294067.2 -294067.2 -294067.3 -294067.4 -294067.5 -294067.6
#> [281] -294067.7 -294067.8 -294067.9 -294068.0 -294068.1 -294068.2 -294068.3
#> [288] -294068.4 -294068.5 -294068.6 -294068.7 -294068.8 -294068.9 -294069.0
#> [295] -294069.1 -294069.2 -294069.2 -294069.3 -294069.4 -294069.5 -294069.6
#> [302] -294069.7 -294069.7 -294069.8 -294069.9 -294070.1 -294070.1 -294070.2
#> [309] -294070.3 -294070.4 -294070.5 -294070.6 -294070.8 -294070.9 -294070.9
#> [316] -294071.1 -294071.1 -294071.2 -294071.2

Created on 2023-08-21 with reprex v2.0.2

@Ananyapam7
Author

Thank you very much for addressing my issue, I really appreciate it. The examples above are quite useful. Could you also tell me why the nlopt backend seems to always converge regardless of the starting values, whereas the torch backend might diverge depending on the initialization? Also, regarding the ELBO extracted above with barents_torch$optim_par$objective: does this value include constants like e or pi, or are the constants not included in computing the ELBO?

@mahendra-mariadassou
Collaborator

nlopt doesn't always converge. For example, we encountered instances where torch (with default parameters) converged while nlopt didn't. It all boils down to the optimization strategy (starting point and optimization parameters such as learning rate, weight decay, etc.). Basically, (i) the closer the starting point is to a stationary point and (ii) the smaller the learning rate, the more likely you are to converge. But it's not always easy to choose a starting point close to a stationary point, and small learning rates lead to longer convergence times, so there is a trade-off.

For nlopt, the default strategy is CCSAQ (you can read here for more details on the algorithm), which has good convergence properties and seems to be more robust than RPROP (the default we use for torch).
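
As a rough illustration of that trade-off, here is a sketch of my own (reusing the config_optim argument shown above, and sticking to small learning rates since larger ones may diverge on this data set):

library(PLNmodels)
data(barents)

## Fit the same model with a few learning rates and compare how long
## the optimization runs and where the objective ends up
learning_rates <- c(0.02, 0.01, 0.005)
fits <- lapply(learning_rates, function(lr) {
  PLN(Abundance ~ 1, data = barents,
      control = PLN_param(backend = "torch", config_optim = list(lr = lr)))
})
data.frame(
  lr         = learning_rates,
  iterations = sapply(fits, function(fit) length(fit$optim_par$objective)),
  final      = sapply(fits, function(fit) tail(fit$optim_par$objective, 1))
)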

Regarding the ELBO, the detailed formula is given in:

torch_elbo = function(data, params, index = torch_tensor(1:self$n)) {
  S2 <- torch_square(params$S[index])
  Z  <- data$O[index] + params$M[index] + torch_mm(data$X[index], params$B)
  res <- .5 * sum(data$w[index]) * torch_logdet(private$torch_Sigma(data, params, index)) +
    sum(data$w[index, NULL] * (torch_exp(Z + .5 * S2) - data$Y[index] * Z - .5 * torch_log(S2)))
  res
},

which, as you can see, doesn't involve the constants but only the terms depending on the variational and model parameters.
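
Written out (my own transcription of the code above, up to sign conventions), the quantity handled by torch_elbo is

$$
\frac{\sum_i w_i}{2}\,\log\det\Sigma \;+\; \sum_{i,j} w_i \Big( e^{Z_{ij} + S_{ij}^2/2} - Y_{ij} Z_{ij} - \tfrac{1}{2}\log S_{ij}^2 \Big),
\qquad Z = O + M + XB,
$$

so only the model parameters (B, Sigma) and the variational parameters (M, S) enter; constant terms such as the log-factorials of Y are dropped.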

@jchiquet
Member

Hi,

Just to complete Mahendra's answer: we use the torch_elbo function for optimisation, but to represent and send back the final value in optim_par we use torch_vloglik (the variational log-likelihood), which includes all constants, as seen here:

torch_vloglik = function(data, params) {
  S2 <- torch_square(params$S)
  Ji <- .5 * self$p - rowSums(.logfactorial(as.matrix(data$Y))) + as.numeric(
    .5 * torch_logdet(params$Omega) +
      torch_sum(data$Y * params$Z - params$A + .5 * torch_log(S2), dim = 2) -
      .5 * torch_sum(torch_mm(params$M, params$Omega) * params$M + S2 * torch_diag(params$Omega), dim = 2)
  )
  attr(Ji, "weights") <- as.numeric(data$w)
  Ji
},
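
For completeness, here is my own transcription of that per-sample quantity into a formula (assuming, as in torch_elbo above, that A stands for exp(Z + S^2/2)):

$$
J_i = \frac{p}{2} - \sum_j \log(Y_{ij}!) + \frac{1}{2}\log\det\Omega
+ \sum_j \Big( Y_{ij} Z_{ij} - A_{ij} + \tfrac{1}{2}\log S_{ij}^2 \Big)
- \frac{1}{2}\Big( M_{i\cdot}\,\Omega\,M_{i\cdot}^{\top} + \sum_j S_{ij}^2\,\Omega_{jj} \Big),
$$

which does include the constant terms (the p/2 term and the log-factorials of Y).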

To get additional information about the convergence status of the algorithm (including for nlopt), check barents_torch$optim_par$status:

> PLNmodels:::status_to_message(barents_torch$optim_par$status)
[1] "success, ftol_rel or ftol_abs was reached"

@Ananyapam7
Author

Ananyapam7 commented Oct 22, 2023

Hi, thanks a lot for your insights. I wanted to dive a bit deeper into the code and have some more questions. In the paper on Variational Inference for Probabilistic Poisson PCA, Section 5.3 (implementation details) mentions that the initialisation is a bit tricky and is done in a specific way. Can you point me to where in the code this is done (for both the nlopt/C++ and torch backends)?

@mahendra-mariadassou
Collaborator

Hi, for PLNPCA it's a bit involved but the crux is here:

initialize = function(ranks, responses, covariates, offsets, weights, formula, control) {
  ## initialize the required fields
  super$initialize(responses, covariates, offsets, weights, control)
  private$params <- ranks
  ## save some time by using a common SVD to define the inceptive models
  control$inception <- PLNfit$new(responses, covariates, offsets, weights, formula, control)
  control$svdM <- svd(control$inception$var_par$M, nu = max(ranks), nv = ncol(responses))
  ## instantiate as many models as ranks
  self$models <- lapply(ranks, function(rank){
    model <- PLNPCAfit$new(rank, responses, covariates, offsets, weights, formula, control)
    model
  })
},

For PLNPCA, we fit a family of models (one per number of dimensions in the tested range) and all of them are initialised simultaneously using a three-step procedure (a standalone toy illustration of the final truncation step is given after the list):

  • Initialize a standard (full-rank) PLN model
    super$initialize(responses, covariates, offsets, weights, control)

    This refers to the initialize method of the PLN class, which uses either an already computed PLN model (inception) or independent linear models on the log-counts (scaled by the offsets) for each feature, here:

    PLNmodels/R/PLNfit-class.R, lines 289 to 311 at 2657e50:

    initialize = function(responses, covariates, offsets, weights, formula, control) {
      ## problem dimensions
      n <- nrow(responses); p <- ncol(responses); d <- ncol(covariates)
      ## set up various quantities
      private$formula <- formula # user formula call
      ## initialize the variational parameters
      if (isPLNfit(control$inception)) {
        if (control$trace > 1) cat("\n User defined inceptive PLN model")
        stopifnot(isTRUE(all.equal(dim(control$inception$model_par$B), c(d,p))))
        private$Sigma <- control$inception$model_par$Sigma
        private$B     <- control$inception$model_par$B
        private$M     <- control$inception$var_par$M
        private$S     <- control$inception$var_par$S
      } else {
        if (control$trace > 1) cat("\n Use LM after log transformation to define the inceptive model")
        fits <- lm.fit(weights * covariates, weights * log((1 + responses)/exp(offsets)))
        private$B <- matrix(fits$coefficients, d, p)
        private$M <- matrix(fits$residuals, n, p)
        private$S <- matrix(.1, n, p)
      }
      private$optimizer$main   <- ifelse(control$backend == "nlopt", nlopt_optimize, private$torch_optimize)
      private$optimizer$vestep <- nlopt_optimize_vestep
    },
  • With this M in hand, we then apply an SVD to the M matrix to define an inception model
    control$inception <- PLNfit$new(responses, covariates, offsets, weights, formula, control)
    control$svdM <- svd(control$inception$var_par$M, nu = max(ranks), nv = ncol(responses))
  • This inception is then used to initialize each PLNPCA(fit) model
    model <- PLNPCAfit$new(rank, responses, covariates, offsets, weights, formula, control)

    Each of those models will call its own initialize function and, since we provided a (low-rank) starting point for M (in the form of svdM), it will use it to initialize its own M, S and C parameters:
    if (!is.null(control$svdM)) {
      svdM <- control$svdM
    } else {
      svdM <- svd(private$M, nu = rank, nv = self$p)
    }
    ### TODO: check that it is really better than initializing with zeros...
    private$M <- svdM$u[, 1:rank, drop = FALSE] %*% diag(svdM$d[1:rank], nrow = rank, ncol = rank) %*% t(svdM$v[1:rank, 1:rank, drop = FALSE])
    private$S <- matrix(0.1, self$n, rank)
    private$C <- svdM$v[, 1:rank, drop = FALSE] %*% diag(svdM$d[1:rank], nrow = rank, ncol = rank)/sqrt(self$n)
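
As announced above, here is a standalone toy illustration of that truncation step (my own example with a random matrix, mirroring the snippet above; it is not package code):

## Toy illustration of the rank-r SVD truncation used to initialize M and C
set.seed(1)
M    <- matrix(rnorm(50 * 10), nrow = 50, ncol = 10) # stand-in for the full-rank variational means
rank <- 3

svdM  <- svd(M, nu = rank, nv = ncol(M))
M_low <- svdM$u[, 1:rank, drop = FALSE] %*%
  diag(svdM$d[1:rank], nrow = rank, ncol = rank) %*%
  t(svdM$v[1:rank, 1:rank, drop = FALSE])
C_init <- svdM$v[, 1:rank, drop = FALSE] %*%
  diag(svdM$d[1:rank], nrow = rank, ncol = rank) / sqrt(nrow(M))

dim(M_low)  # n x rank: low-dimensional variational means used to initialize M
dim(C_init) # p x rank: loadings used to initialize C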
