In [1]:
using LinearAlgebra, Distributions
include("tree_simulation_v2.jl")
include("compute_magnitudes_of_the_dynamics.jl")

compute_magnitudes (generic function with 1 method)

In [2]:
# Only the internal nodes are taken into account for the reconstruction.
"""
    fill_adj(k::Int64, c::Int64) -> Dict{Int64, Vector{Int64}}

Build the undirected adjacency list of the first `k` levels of a complete tree
with branching factor `c`.

Nodes are numbered from 0 (the root) and the children of `node` are
`c*node + 1, ..., c*node + c`. Only children whose index is below the number of
nodes in those `k` levels are linked. Each entry lists the parent first (except
for node 0) followed by the children that were kept.

For `c == 1` the tree is a chain of `k` nodes; this case is handled separately
because the geometric-sum closed form divides by `c - 1`.
"""
function fill_adj(k::Int64, c::Int64)
    # 1 + c + ... + c^(k-1); the closed form is undefined for c == 1.
    internal_nodes = c == 1 ? k : (c^k - 1) ÷ (c - 1)

    # One (initially empty) neighbour list per node.
    adj_list = Dict{Int64, Array{Int64, 1}}(
        node => Int64[] for node in 0:internal_nodes-1
    )

    # Link every node with those of its children that fall inside the node range.
    for node in 0:internal_nodes-1
        for i in 1:c
            child = c * node + i
            if child < internal_nodes
                push!(adj_list[node], child)
                push!(adj_list[child], node)
            end
        end
    end

    return adj_list
end


fill_adj (generic function with 1 method)

In [5]:
# Model and tree parameters
q = 2                 # number of states
L = 10                # sequence length
r = 3                 # connectivity
k = 9                 # tree depth is k+1
c = 2                 # number of children per node
β = 0.1               # inverse temperature
t = rand(1:9, k + 1)  # evolutionary times (one random integer in 1:9 per level)

# Gather the parameters in a dictionary keyed by Symbol
tree_params = Dict{Symbol, Any}(
    :q => q,
    :L => L,
    :r => r,
    :k => k,
    :c => c,
    :β => β,
    :times => t,
)

gamma = 1.0

1.0

In [7]:
# Adjacency list over the nodes produced by fill_adj for the global k and c.
adj_list = fill_adj(k,c);

Dict{Int64, Vector{Int64}} with 511 entries:
  56  => [27, 113, 114]
  35  => [17, 71, 72]
  425 => [212]
  429 => [214]
  60  => [29, 121, 122]
  220 => [109, 441, 442]
  308 => [153]
  67  => [33, 135, 136]
  215 => [107, 431, 432]
  73  => [36, 147, 148]
  319 => [159]
  251 => [125, 503, 504]
  115 => [57, 231, 232]
  112 => [55, 225, 226]
  185 => [92, 371, 372]
  348 => [173]
  420 => [209]
  404 => [201]
  365 => [182]
  417 => [208]
  333 => [166]
  86  => [42, 173, 174]
  168 => [83, 337, 338]
  364 => [181]
  207 => [103, 415, 416]
  ⋮   => ⋮

In [9]:
# Run the simulation with tree_params (run_simulation is presumably defined in the
# included tree_simulation_v2.jl — confirm). On a first call the reported time is
# dominated by compilation.
@time tree, leaves, C, data_mean = run_simulation(tree_params);

  6.266041 seconds (4.57 M allocations: 306.797 MiB, 4.23% gc time, 99.57% compilation time)


In [None]:
# Compute the quantities consumed below: H and h feed local_parameters, J feeds
# message_passing.
@time times, Lambda, Sigma, J, H, h = compute_magnitudes(tree, leaves, C, gamma, t, q, L);

In [None]:
"""
    edges(tree) -> (upward_edges, downward_edges)

Collect the directed edges between non-leaf nodes of `tree`.

`tree[node][:leaf]` flags leaves and `tree[node][:children]` lists the children
of a non-leaf node. A child is kept only if its index does not exceed the
largest non-leaf index.

# Returns
- `upward_edges`: `(child, parent)` tuples sorted by decreasing child index.
- `downward_edges`: `(parent, child)` tuples sorted in increasing order.

The ordering makes the result independent of `Dict` iteration order. Assuming
child indices are larger than their parent's (as in `c*node + i` numbering —
confirm for trees built elsewhere), the upward list goes from the deepest edges
to the root and the downward list from the root outwards, which is the schedule
a single belief-propagation sweep needs.

Both lists are empty when the tree has no non-leaf node.
"""
function edges(tree::Dict{Int64,Dict{Symbol,Any}})
    upward_edges = Array{Tuple{Int64,Int64},1}()
    downward_edges = Array{Tuple{Int64,Int64},1}()

    internal_nodes = [node for (node, info) in tree if !info[:leaf]]
    # Nothing to connect (and `maximum` would throw on an empty collection).
    isempty(internal_nodes) && return upward_edges, downward_edges
    max_node = maximum(internal_nodes)

    for node in internal_nodes
        for child in tree[node][:children]
            if child <= max_node
                push!(upward_edges, (child, node))
                push!(downward_edges, (node, child))
            end
        end
    end

    # Deterministic schedule: deepest children first going up, root first going down.
    sort!(upward_edges; by=first, rev=true)
    sort!(downward_edges)

    return upward_edges, downward_edges
end

# Define the quantities needed to run BP (belief propagation)


In [None]:
# Mean and precision matrix of the term that depends only on x_i.
"""
    local_parameters(tree, H, h) -> (mu_local, inv_sigma_local)

Compute, for every non-leaf node of `tree`, the mean and the precision (inverse
covariance) of the single-node Gaussian factor described by `H[node]` and
`h[node]`.

The mean is `-0.5 * H⁻¹ h`, i.e. the stationary point of `x'Hx + h'x`. Under
that same form the covariance is `-0.5 * H⁻¹` and the precision is `-2H`.
`inv_sigma_local` holds the precision, because the message updates and the
marginal computation add it to message precisions and invert the sum.

NOTE(review): this assumes the factor is `exp(x'Hx + h'x)` with `H` negative
definite, as implied by the mean formula — confirm against the code that
produces `H` and `h`.

# Returns
- `mu_local::Dict{Int64,Vector{Float64}}`: local means.
- `inv_sigma_local::Dict{Int64,Matrix{Float64}}`: local precision matrices.
"""
function local_parameters(tree::Dict{Int64,Dict{Symbol,Any}}, H::Dict{Int64,Array{Float64,2}}, h::Dict{Int64,Array{Float64,1}})

    mu_local = Dict{Int64, Array{ Float64, 1}}()
    inv_sigma_local = Dict{Int64, Array{ Float64, 2}}()

    for (node, info) in tree
        info[:leaf] && continue

        # Solve H x = h instead of forming inv(H) explicitly.
        mu_local[node] = -0.5 * (H[node] \ h[node])
        # Precision, not covariance: the inverse of -0.5 * inv(H).
        inv_sigma_local[node] = -2.0 * H[node]
    end

    return mu_local, inv_sigma_local
end

In [None]:
"""
    init_messages!(message_mean, inv_message_precision, tree, q, L)

Initialise the messages on every directed edge between non-leaf nodes of `tree`.

For each non-leaf `node` and each child kept by the same rule as `edges`
(child index not above the largest non-leaf index), both `(node, child)` and
`(child, node)` get a zero mean vector of length `q*L` and a `q*L × q*L`
identity matrix.

Only edges are initialised. Allocating a matrix for every ordered pair of
non-leaf nodes grows quadratically with the number of nodes, and the message
updates and marginals only read messages between neighbours.

The dictionaries are modified in place; returns `nothing`.

The initial values could later be replaced by empirical ones taken from the
data (e.g. the empirical mean and correlation matrix).
"""
function init_messages!(message_mean::Dict{Tuple{ Int64, Int64}, Array{ Float64, 1}},
        inv_message_precision::Dict{Tuple{ Int64, Int64}, Array{ Float64, 2}},
        tree::Dict{Int64,Dict{Symbol,Any}},
        q::Int64,
        L::Int64
    )
    internal_nodes = [node for (node, info) in tree if !info[:leaf]]
    isempty(internal_nodes) && return nothing
    max_node = maximum(internal_nodes)
    dim = q * L

    for node in internal_nodes
        for child in tree[node][:children]
            child <= max_node || continue
            # One message per direction of the edge.
            for edge in ((node, child), (child, node))
                message_mean[edge] = zeros(dim)
                inv_message_precision[edge] = Matrix{Float64}(I, dim, dim)
            end
        end
    end

    return nothing
end

# Message from i to j
"""
    update_messages!(i, j, message_mean, inv_message_precision,
                     mu_local, inv_sigma_local, J, adj_list)

Recompute the message sent from node `i` to node `j` and store it in
`inv_message_precision[i,j]` and `message_mean[i,j]`.

The update combines the local term of `i` with the messages `i` receives from
all of its neighbours in `adj_list[i]` except `j`, and the coupling `J[i,j]`.
Returns the new mean vector.

NOTE(review): the precision update uses the matrix square `J[i,j]^2` and the
mean update uses `inv(J[i,j])`; for matrix-valued couplings confirm this matches
the intended Gaussian BP equations.
"""
function update_messages!(i,j,
        message_mean::Dict{Tuple{ Int64, Int64}, Array{ Float64, 1}},
        inv_message_precision::Dict{Tuple{ Int64, Int64}, Array{ Float64, 2}},
        mu_local::Dict{Int64, Array{ Float64, 1}},
        inv_sigma_local::Dict{Int64, Array{ Float64, 2}},
        J::Dict{Tuple{Int64,Int64},Array{Float64,2}},
        adj_list::Dict{Int64,Array{Int64,1}} 
    )
    # Accumulate the precisions arriving at i from every neighbour other than j.
    cavity_precision = zeros(size(inv_message_precision[i,j]))
    for k in adj_list[i]
        k != j || continue
        cavity_precision = cavity_precision + inv_message_precision[k,i]
    end
    inv_message_precision[i,j] = -(J[i,j])^2*inv(inv_sigma_local[i] + cavity_precision)

    # Accumulate the precision-weighted means arriving at i from the same neighbours.
    cavity_potential = zeros(size(message_mean[i,j]))
    for k in adj_list[i]
        k != j || continue
        cavity_potential = cavity_potential + inv_message_precision[k,i]*message_mean[k,i]
    end
    new_mean = -inv(J[i,j])*(inv_sigma_local[i]*mu_local[i] + cavity_potential)
    message_mean[i,j] = new_mean

    return new_mean
end

# This function first collects the messages sent from child nodes to their parents
# and then updates the messages in the opposite direction.
"""
    sweep!(upward_edges, downward_edges, message_mean, inv_message_precision,
           mu_local, inv_sigma_local, J, adj_list)

Run one pass of message updates: every edge of `upward_edges` in order, followed
by every edge of `downward_edges` in order. Each edge `(src, dst)` triggers
`update_messages!(src, dst, ...)`, which mutates the message dictionaries.
Returns `nothing`.
"""
function sweep!(upward_edges::Array{Tuple{Int64,Int64},1},
        downward_edges::Array{Tuple{Int64,Int64},1},
        message_mean::Dict{Tuple{ Int64, Int64}, Array{ Float64, 1}},
        inv_message_precision::Dict{Tuple{ Int64, Int64}, Array{ Float64, 2}},
        mu_local::Dict{Int64, Array{ Float64, 1}},
        inv_sigma_local::Dict{Int64, Array{ Float64, 2}},
        J::Dict{Tuple{Int64,Int64},Array{Float64,2}},
        adj_list::Dict{Int64,Array{Int64,1}} 
    )
    # Upward edges are exhausted before the first downward edge is visited.
    schedule = Iterators.flatten((upward_edges, downward_edges))

    for (src, dst) in schedule
        update_messages!(src, dst, message_mean, inv_message_precision,
                         mu_local, inv_sigma_local, J, adj_list)
    end

    return nothing
end

"""
    message_passing(tree, adj_list, mu_local, inv_sigma_local, J, q, L)
        -> (message_mean, inv_message_precision)

Create the message dictionaries, initialise them with `init_messages!`, and run
one upward/downward `sweep!` over the edges returned by `edges(tree)`.

The local parameters passed as `mu_local` and `inv_sigma_local` are used as
given. They are not recomputed from global `H`/`h` variables, so the function
depends only on its arguments.

# Returns
- `message_mean`: dictionary of message means keyed by `(sender, receiver)`.
- `inv_message_precision`: dictionary of message matrices with the same keys.
"""
function message_passing(tree::Dict{Int64,Dict{Symbol,Any}},
        adj_list::Dict{Int64,Array{Int64,1}},
        mu_local::Dict{Int64, Array{ Float64, 1}},
        inv_sigma_local::Dict{Int64, Array{ Float64, 2}},  
        J::Dict{Tuple{Int64,Int64},Array{Float64,2}},
        q::Int64,
        L::Int64
    )
    message_mean = Dict{Tuple{ Int64, Int64}, Array{ Float64, 1}}()
    inv_message_precision = Dict{Tuple{ Int64, Int64}, Array{ Float64, 2}}()

    upward_edges, downward_edges = edges(tree)

    init_messages!(message_mean, inv_message_precision, tree, q, L)
    sweep!(upward_edges, downward_edges, message_mean, inv_message_precision,
           mu_local, inv_sigma_local, J, adj_list)

    return message_mean, inv_message_precision
end

In [None]:
"""
    compute_Marginals(tree, adj_list, message_mean, inv_message_precision,
                      mu_local, inv_sigma_local, q, L)
        -> (marginal_mean, marginal_precision)

Combine, for every non-leaf node of `tree`, its local term with the messages it
receives from all neighbours listed in `adj_list[node]`.

The marginal matrix is `inv_sigma_local[node]` plus the sum of incoming message
matrices, symmetrised as `(P + P') / 2`. The marginal mean is the inverse of
that matrix applied to the local term plus the sum of incoming
`matrix * mean` products. Vectors have length `q*L`.
"""
function compute_Marginals(tree::Dict{Int64,Dict{Symbol,Any}},
        adj_list::Dict{Int64,Array{Int64,1}},
        message_mean::Dict{Tuple{ Int64, Int64}, Array{ Float64, 1}},
        inv_message_precision::Dict{Tuple{ Int64, Int64}, Array{ Float64, 2}},
        mu_local::Dict{Int64, Array{ Float64, 1}},
        inv_sigma_local::Dict{Int64, Array{ Float64, 2}},        
        q::Int64,
        L::Int64
    )
    dim = q*L
    marginal_mean = Dict{Int64, Array{ Float64, 1}}()
    marginal_precision = Dict{Int64, Array{ Float64, 2}}()

    for (node, info) in tree
        info[:leaf] && continue

        # Sum of the matrices carried by all incoming messages.
        incoming_precision = zeros(dim, dim)
        for k in adj_list[node]
            incoming_precision = incoming_precision + inv_message_precision[k,node]
        end
        precision = inv_sigma_local[node] + incoming_precision

        # Force the marginal precision matrix to be symmetric.
        precision = (precision + precision') / 2
        marginal_precision[node] = precision

        # Sum of the matrix-weighted means carried by all incoming messages.
        incoming_potential = zeros(dim)
        for k in adj_list[node]
            incoming_potential = incoming_potential + inv_message_precision[k,node]*message_mean[k,node]
        end
        marginal_mean[node] = inv(precision)*(inv_sigma_local[node]*mu_local[node] + incoming_potential)
    end

    return marginal_mean, marginal_precision
end

In [None]:
"""
    sample_from_marginals(tree, marginal_mean, marginal_precision, data_mean)
        -> Dict{Int64, Vector{Float64}}

Draw one sample per non-leaf node of `tree` from the Gaussian with mean
`marginal_mean[node]` and covariance `inv(marginal_precision[node])`, then add
`data_mean` elementwise to the sample.

The covariance is symmetrised as `(Σ + Σ') / 2` before it is handed to
`MvNormal`, to remove the round-off asymmetry left by the inversion. The result
is unchanged for an already symmetric matrix, unlike the product `Σ * Σ'`.

Nodes are visited in increasing index order. An empty dictionary is returned
when the tree has no non-leaf node.
"""
function sample_from_marginals(tree::Dict{Int64,Dict{Symbol,Any}},
        marginal_mean::Dict{Int64, Array{ Float64, 1}},
        marginal_precision::Dict{Int64, Array{ Float64, 2}},
        data_mean::Array{Float64,1}
    )
    Infered_seq = Dict{Int64, Array{Float64,1}}()

    internal_nodes = sort!([node for (node, info) in tree if !info[:leaf]])

    for node in internal_nodes
        μ = marginal_mean[node]
        covariance = inv(marginal_precision[node])
        # Remove numerical asymmetry; MvNormal needs a symmetric positive-definite matrix.
        covariance = (covariance + covariance') / 2

        draw = rand(MvNormal(μ, covariance))

        # Shift the sample by the data mean.
        Infered_seq[node] = draw .+ data_mean
    end

    return Infered_seq
end

# There seems to be a problem: for instance, the precision matrix of the root node
# ends up almost equal to the empirical covariance matrix, whereas precision matrices
# are, as far as I understand, the inverses of covariance matrices.
"""
    sample_wrong(tree, marginal_mean, marginal_precision, data_mean)
        -> Dict{Int64, Vector{Float64}}

Diagnostic variant of the sampler: for every node index from 0 up to the largest
non-leaf index of `tree`, draw from `MvNormal(marginal_mean[node],
marginal_precision[node])` — the "precision" matrix is passed directly as the
covariance, without inversion — and add `data_mean` elementwise.
"""
function sample_wrong(tree::Dict{Int64,Dict{Symbol,Any}},
        marginal_mean::Dict{Int64, Array{ Float64, 1}},
        marginal_precision::Dict{Int64, Array{ Float64, 2}},
        data_mean::Array{Float64,1}
    )
    Infered_seq = Dict{Int64, Array{Float64,1}}()

    last_internal = maximum([node for (node, info) in tree if !info[:leaf]])

    for node in 0:last_internal
        # The stored matrix is used as-is in the covariance slot.
        dist = MvNormal(marginal_mean[node], marginal_precision[node])

        # Draw one vector and shift it by the data mean.
        Infered_seq[node] = rand(dist) .+ data_mean
    end

    return Infered_seq
end

In [None]:
# Per-node local parameters computed from H and h.
mu_local, inv_sigma_local = local_parameters(tree, H, h); # local-field parameters
# Run message passing on the tree and time the call.
@time message_mean, inv_message_precision = message_passing(tree, adj_list, mu_local, inv_sigma_local, J, q, L);

In [None]:
# Combine the local parameters with the messages into per-node marginals.
marginal_mean, marginal_precision = compute_Marginals(tree, adj_list, message_mean, inv_message_precision, mu_local, inv_sigma_local, q, L)

In [None]:
"""
    check_symmetric_matrices(dict)

Iterate over the `(key, matrix)` pairs of `dict` and print a message for every
matrix for which `issymmetric` is false. Nothing is printed for symmetric
matrices. Returns `nothing`.
"""
function check_symmetric_matrices(dict)
    for (key, matrix) in dict
        # Symmetric matrices need no report.
        issymmetric(matrix) && continue
        println("La matriz asociada a la clave '$key' no es simétrica.")
    end
end

In [None]:
# Report any marginal precision matrix that is not symmetric.
check_symmetric_matrices(marginal_precision)

In [None]:
# Eigenvalues of the marginal precision matrix stored for node 5
eigvals_A = eigen(marginal_precision[5]).values

# Print the eigenvalues
println("Autovalores de la matriz:")
println(eigvals_A)

# The matrix is positive definite only if every eigenvalue is strictly positive
println(all(>(0), eigvals_A) ? "La matriz es definida positiva" : "La matriz no es definida positiva")

In [None]:
# Inverse of the marginal matrix stored for node 0 (the root in fill_adj's numbering).
inv(marginal_precision[0])

In [None]:
# Marginal matrix stored for node 0, shown for comparison with its inverse.
marginal_precision[0]

In [None]:
# Sample one vector per node using the inverted marginal matrices.
Infered_seq = sample_from_marginals(tree, marginal_mean, marginal_precision, data_mean)

In [None]:
# Same sampling with the marginal matrices used directly as covariances
# (overwrites Infered_seq from the previous cell).
Infered_seq = sample_wrong(tree, marginal_mean, marginal_precision, data_mean)

In [None]:
# Show the :seq entry stored in the tree for node 1.
tree[1][:seq]

In [None]:
# Add data_mean elementwise to the :binary_seq entry of leaves[9].
leaves[9][:binary_seq] .+ data_mean