Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug when calculating weight matrix for SIMMAP trees #1

Open
kopperud opened this issue Feb 14, 2024 · 0 comments
Open

bug when calculating weight matrix for SIMMAP trees #1

kopperud opened this issue Feb 14, 2024 · 0 comments

Comments

@kopperud
Copy link
Owner

kopperud commented Feb 14, 2024

In slouch version 2.1.4 there is a bug in calculating the weight matrix for a model with discrete predictors mapped as regimes on the tree. We discovered this while working on an alternative (faster) implementation of the likelihood. I'm copy-pasting a minimal working example below.

## minimum working example
library(ape)
library(phytools)
library(slouch)

############################################################
#                                                          #
#            State-dependent pruning                       #
# !! σ2, α are scalars, θ is a vector of length # states   #   
#                                                          #
############################################################

# Carry a child node's partial-likelihood summary rootward along one mapped
# branch.
#
# `segments` is the named vector of sub-edge durations for the branch (names
# are regime labels used to index θ). Segments are visited in reverse storage
# order, i.e. starting from the end of the vector; this presumes the vector is
# stored parent-to-child as in phytools `maps` -- verify against the caller.
#
# Returns a list with the propagated `mean` and `var` at the parent end of the
# branch, and `log_nf`, the sum of α * duration over all segments.
.sd_propagate_branch <- function(segments, mean_child, var_child, σ2, α, θ) {
  log_nf <- 0
  # seq_along() (not 1:length()) so an empty segment vector is a no-op
  for (i in rev(seq_along(segments))) {
    state <- names(segments)[i]
    delta_t <- segments[[i]]

    # variance accumulated on this segment, plus the variance inherited from
    # below rescaled by exp(2 * α * delta_t); the rescaling is applied on
    # every segment
    v_segment <- σ2 / (2 * α) * expm1(2.0 * α * delta_t)
    var_child <- v_segment + var_child * exp(2.0 * α * delta_t)

    # mean is pulled relative to the optimum of the segment's regime
    mean_child <- exp(α * delta_t) * (mean_child - θ[[state]]) + θ[[state]]

    log_nf <- log_nf + delta_t * α
  }
  list(mean = mean_child, var = var_child, log_nf = log_nf)
}

# Post-order (tips first) traversal that fills in, for `node_index` and every
# node below it, the mean (μ), variance (V) and log normalising factor
# (log_norm_factor) used by the pruning likelihood.
#
# Arguments:
#   node_index        node to process (tips are 1..ntip, internal nodes > ntip)
#   edge              two-column edge matrix (parent, child)
#   tree              tree object; only `tree$tip.label` is read here
#   continuousChar    trait values, named by tip label
#   μ, V, log_norm_factor
#                     per-node vectors, indexed by node number; updated copies
#                     are returned
#   subedges_lengths  list with one named vector of segment durations per edge,
#                     in the same order as the rows of `edge`
#   σ2, α             scalars
#   θ                 optima, indexed by regime name
#
# Returns an unnamed list: list(μ, V, log_norm_factor).
#
# Stops with an error if a tip has no entry in `continuousChar`, or if an
# internal node does not have exactly two child edges (previously extra
# children were silently ignored).
sd_postorder <- function(node_index, edge, tree, continuousChar,
                         μ, V, log_norm_factor, subedges_lengths, σ2, α, θ){
  ntip <- length(tree$tip.label)

  # tip: the observed value is the mean, with zero variance
  if (node_index <= ntip){
    species <- tree$tip.label[node_index]
    trait_index <- match(species, names(continuousChar))
    if (is.na(trait_index)){
      stop("no trait value found for tip '", species, "'", call. = FALSE)
    }

    μ[node_index] <- continuousChar[[trait_index]]
    V[node_index] <- 0.0 ## if there is no observation error

    return(list(μ, V, log_norm_factor))
  }

  # internal node: the merge formulas below combine exactly two children
  child_edges <- which(edge[, 1] == node_index)
  if (length(child_edges) != 2L){
    stop("node ", node_index, " has ", length(child_edges),
         " child edges; sd_postorder() requires a bifurcating tree",
         call. = FALSE)
  }
  left_edge <- child_edges[1]   # index of left child edge
  right_edge <- child_edges[2]  # index of right child edge
  left <- edge[left_edge, 2]    # index of left child node
  right <- edge[right_edge, 2]  # index of right child node

  # recurse into the left subtree, then the right one, threading the
  # per-node vectors through both calls
  output_left <- sd_postorder(left, edge, tree, continuousChar,
                              μ, V, log_norm_factor, subedges_lengths, σ2, α, θ)
  μ <- output_left[[1]]
  V <- output_left[[2]]
  log_norm_factor <- output_left[[3]]

  output_right <- sd_postorder(right, edge, tree, continuousChar,
                               μ, V, log_norm_factor, subedges_lengths, σ2, α, θ)
  μ <- output_right[[1]]
  V <- output_right[[2]]
  log_norm_factor <- output_right[[3]]

  # propagate each child's summary up its branch, segment by segment
  branch_left <- .sd_propagate_branch(subedges_lengths[[left_edge]],
                                      μ[left], V[left], σ2, α, θ)
  branch_right <- .sd_propagate_branch(subedges_lengths[[right_edge]],
                                       μ[right], V[right], σ2, α, θ)

  mean_left <- branch_left$mean
  mean_right <- branch_right$mean
  var_left <- branch_left$var
  var_right <- branch_right$var

  ## compute the mean and variance of the node
  μ[node_index] <- (mean_left * var_right + mean_right * var_left) / (var_left + var_right)
  V[node_index] <- (var_left * var_right) / (var_left + var_right)

  ## compute the normalizing factor, the left-hand side of the pdf of the normal variable
  contrast <- mean_left - mean_right
  a <- -(contrast * contrast / (2 * (var_left + var_right)))
  b <- log(2 * pi * (var_left + var_right)) / 2.0
  log_norm_factor[node_index] <- branch_left$log_nf + branch_right$log_nf + a - b

  list(μ, V, log_norm_factor)
}

# Log-likelihood of a continuous trait on a SIMMAP tree, computed by a
# post-order pruning pass (sd_postorder) followed by a root correction.
#
# Arguments:
#   tree            tree object; reads `tree$tip.label`, `tree$edge` and
#                   `tree$maps` (one named vector of segment durations per edge)
#   continuousChar  trait values, named by tip label
#   σ2, α           scalars
#   θ               optima, indexed by regime name
#
# Returns a single number: the root density term plus the sum of the per-node
# log normalising factors.
sd_logL_pruning <- function(tree, continuousChar, σ2, α, θ){
  ntip <- length(tree$tip.label) # number of tips
  edge <- tree$edge # equals tree[:edge] in Julia
  max_node_index <- max(tree$edge) # total number of nodes

  V <- numeric(max_node_index)
  μ <- numeric(max_node_index)
  log_norm_factor <- numeric(max_node_index)

  subedges_lengths <- tree$maps

  root_index <- ntip + 1

  output <- sd_postorder(root_index, edge, tree, continuousChar,
                         μ, V, log_norm_factor, subedges_lengths, σ2, α, θ)
  μ <- output[[1]]
  V <- output[[2]]
  log_norm_factor <- output[[3]]

  ## assume root value equal to theta
  μ_root <- μ[root_index]
  v_root <- V[root_index]
  left_edge_from_root <- which(edge[, 1] == root_index)[1] # left child edge index of root node
  left_subedges_from_root <- subedges_lengths[[left_edge_from_root]] # sub-edge lengths

  # Root state = regime of the first stored segment on the root's left child
  # edge (presumably the segment adjacent to the root, since phytools stores
  # `maps` from parent to child -- verify).
  # The previous `names(tail(x))[[1]]` only returned that segment when the edge
  # had at most 6 segments, because tail() keeps the last 6 elements by default.
  root_state <- names(left_subedges_from_root)[[1]]
  lnl <- dnorm(θ[[root_state]], mean = μ_root, sd = sqrt(v_root), log = TRUE)

  ## add norm factor
  # accumulated one term at a time (rather than via sum()) so the
  # floating-point result is unchanged from earlier versions
  for (log_nf in log_norm_factor){
    lnl <- lnl + log_nf
  }
  lnl
}

And here is some code to test the two ways of computing the likelihood:

###################################################
#                                                 #
#                     Testing...                  #   
#                                                 #
###################################################

# test with slouch data set
# load the example tree and trait table used by this test
data("artiodactyla")
data("neocortex")

# convert continuous data to read.nexus.data() format:
# a list with one element per species, named by species
brain <- list()
# seq_along() instead of 1:length() so an empty column yields no iterations
for (i in seq_along(neocortex$brain_mass_g_log_mean)){
  sp <- neocortex$species[i]
  brain[sp] <- list(neocortex$brain_mass_g_log_mean[i])
}

# reorder the trait table to follow the tip order of the tree; an unmatched
# tip would otherwise silently produce an all-NA row
tip_rows <- match(artiodactyla$tip.label, neocortex$species)
stopifnot(!anyNA(tip_rows))
neocortex <- neocortex[tip_rows, ]

# discrete regime per tip, named by species, as input to make.simmap()
diet <- as.character(neocortex$diet)
names(diet) <- neocortex$species
set.seed(123) # fixed seed so the same stochastic map is drawn every run
smaptree <- phytools::make.simmap(artiodactyla, diet)

This is what the SIMMAP tree looks like (it is just one random SIMMAP, but that does not matter for illustrating that the likelihoods are equivalent):

Screenshot from 2024-02-14 15-21-03

# Evaluate the slouch likelihood at a single fixed (a, sigma2_y) point, with
# regimes taken from the SIMMAP and hill-climbing switched off.
m0 <- slouch.fit(
  smaptree,
  species = neocortex$species,
  response = neocortex$brain_mass_g_log_mean,
  fixed.fact = neocortex$diet,
  a_values = 0.01,
  sigma2_y_values = 1.0,
  anc_maps = "simmap",
  hillclimb = FALSE
)

# First column of the primary coefficient table (presumably the regime optima
# estimates -- confirm). dput() echoes the vector to the console and returns
# it, so `theta` holds the same values that are printed.
regime_coefs <- m0$beta_primary$coefficients
theta <- dput(regime_coefs[, 1])

# Same parameter values as the slouch.fit() call above, fed to the pruning
# implementation.
lnl_brain_pruning <- sd_logL_pruning(
  smaptree,
  brain,
  σ2 = 1.0,
  α = 0.01,
  θ = theta
)

In version 2.1.4, these two methods give different results (i.e. different probability densities). There is a fix in 88c6181. After the bug fix, in version 2.1.5 (I will upload a new version to CRAN), the printed likelihoods are identical up to floating-point error:

# Show both log-likelihoods (slouch first, then pruning) at 20 significant
# digits so that any discrepancy beyond rounding is visible.
for (lnl_value in list(m0$modfit$Support, lnl_brain_pruning)) {
  print(lnl_value, digits = 20)
}
[1] -89.861374999422423571
[1] -89.86137499942240936
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant