In [4]:
using TorusEvol
using Distributions
using TimerOutputs
using Plots, StatsPlots
using LogExpFunctions

# Read chains "A" and "B" of entry "1A3N" and take their observed data.
chainY = from_pdb("1A3N", "A")
chainZ = from_pdb("1A3N", "B")
#Y = slice(data(chainY), 1:5, :)
#Z = slice(data(chainZ), 2:7, :)
Y = data(chainY)
Z = data(chainZ)

# Overwrite NaN entries of the second coordinate with 0.0, in place.
data(Y)[2][isnan.(data(Y)[2])] .= 0.0
data(Z)[2][isnan.(data(Z)[2])] .= 0.0

my_diff = jumping(WrappedDiffusion(-1.2, -0.8, 0.7, 0.9, 1.0, 1.0, 0.2), 1.0)
my_diff2 = jumping(WrappedDiffusion(-1.0, -1.0, 1.8, 1.8, 1.0, 1.0, 0.1), 1.0)

w = WAG_SubstitutionProcess()
# These two globals are deliberately NOT called `processes` / `weights`: other cells call
# `processes(ξ)` and `weights(ξ)` as functions, and a global of the same name in this
# namespace would shadow them and make those calls fail.
# Rows of the grid are coordinates, columns are regimes (it is indexed as `[c, e]` elsewhere).
process_grid = [w w; my_diff my_diff2]
regime_weights = [1.0, 0.0]
ξ = MixtureProductProcess(regime_weights, process_grid)

t = 0.4
λ = 0.015; μ = 0.0156; r = 0.6
align_model = TKF92([t], λ, μ, r; known_ancestor=true)
pair_hmm = PairDataHMM(align_model, num_sites(Y), num_sites(Z));
pair_chain_dist = ChainJointDistribution(ξ, align_model)
# Last expression of the cell, so the matrix is what the cell displays.
emission_lps = fulljointlogpdf(ξ, t, Y, Z)

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile exists: 1A3N
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile exists: 1A3N


142×146 Matrix{Float64}:
 -25.0358  -24.7162  -25.0068  -33.7361   …  -28.0232  -66.9887   -13.434
 -36.1451  -15.4942  -16.2317  -25.3725      -19.0279  -32.9741   -13.6597
 -33.6691  -21.2741  -14.5955  -25.4042      -20.5477  -44.5028   -11.7251
 -39.7601  -34.1013  -33.2484   -5.02986     -35.2034  -30.3247    -3.88709
 -40.4012  -32.9982  -32.0806   -7.61196     -34.7408  -27.4761    -3.17999
 -39.8164  -35.0303  -33.3116   -9.24855  …  -34.5567  -28.6569    -3.61262
 -39.9721  -33.9442  -32.4771   -8.79749     -35.2304  -27.4783    -3.5067
 -40.3626  -33.7294  -29.2066   -8.62386     -34.5341  -29.1127    -3.54312
 -39.4897  -34.8798  -32.3605   -9.88389     -33.899   -27.2323    -3.96756
 -42.5864  -31.472   -31.3898   -9.33567     -33.4383  -29.8048    -3.36862
 -39.4256  -33.5821  -32.9801   -8.75491  …  -35.4168  -27.1953    -3.52709
 -41.3326  -33.0187  -31.1069   -7.76237     -33.8885  -28.4409    -3.16163
 -41.5296  -33.063   -31.1095   -7.71654     -33.9511  -28.0034    -

In [20]:
"""
    logpdfs(alignment, emission_lps)

Pick one entry of `emission_lps` for every column of `alignment` and return them, in
order, as a `Vector{Real}`.

A running row index and a running column index both start at zero. A column equal to
`[1, 1]` advances both and reads `emission_lps[row, col]`. A column equal to `[1, 0]`
advances only the row index and reads from the last column of `emission_lps`. Any other
column advances only the column index and reads from the last row.
"""
function logpdfs(alignment, emission_lps)
    n_rows, n_cols = size(emission_lps)
    row = 0
    col = 0
    picked = Real[]
    for column ∈ alignment
        advance_row = column == [1, 1] || column == [1, 0]
        advance_col = column == [1, 1] || !advance_row
        row += advance_row
        col += advance_col
        # A sequence that does not advance is looked up in the trailing row / column.
        r = advance_row ? row : n_rows
        c = advance_col ? col : n_cols
        push!(picked, emission_lps[r, c])
    end
    return picked
end

"""
    logpdfsregime(alignment, regimes, Y, Z, ξ, t)

Return a vector with one log-density per alignment column `m`, evaluated under regime
`regimes[m]` and summed over all `num_coords(ξ)` coordinates of `ξ`.

For each coordinate `c` the process `processes(ξ)[c, regimes[m]]` is used:

- column `[1, 1]`: stationary log-density of the Y site plus the transition
  log-density (time `t`) from the Y site to the Z site;
- column `[1, 0]`: stationary log-density of the Y site;
- any other column: stationary log-density of the Z site.

Site positions in Y and Z are tracked separately from the column index `m`, in the same
way as `logpdfs` does, so gapped columns do not shift the other sequence.

NOTE(review): sites are read as `data(Y)[c][:, i]`, i.e. one column per site of the
per-coordinate matrices, which is how `data(Y)[2]` is used elsewhere in this notebook.
Confirm this matches what callers pass as `Y` and `Z`.
"""
function logpdfsregime(alignment, regimes, Y, Z, ξ, t)
    M = length(alignment)
    C = num_coords(ξ)
    dataY = data(Y)
    dataZ = data(Z)
    # `Real` element type mirrors `logpdfs`, so non-Float64 log-densities can be stored.
    res = zeros(Real, M)
    i = 0  # current site of Y
    j = 0  # current site of Z
    for m ∈ 1:M
        e = regimes[m]
        v = alignment[m]
        has_y = v == [1, 1] || v == [1, 0]
        has_z = v == [1, 1] || !has_y
        i += has_y
        j += has_z
        for c ∈ 1:C
            p = processes(ξ)[c, e]
            if has_y && has_z
                y = dataY[c][:, i]
                z = dataZ[c][:, j]
                res[m] += logpdf(statdist(p), y) + logpdf(transdist(p, t, y), z)
            elseif has_y
                res[m] += logpdf(statdist(p), dataY[c][:, i])
            else
                res[m] += logpdf(statdist(p), dataZ[c][:, j])
            end
        end
    end
    return res
end

logpdfsregime (generic function with 1 method)

In [26]:
# Draw the Y/Z alignment by backward sampling from `pair_hmm.α` under `align_model`;
# its `.data` field is displayed further down in this cell.
# NOTE(review): no call that fills `pair_hmm.α` is visible in this notebook — confirm it
# is populated before this line runs.
M_YZ = backward_sampling(pair_hmm.α, align_model)
"""
    sample_anc_alignment(M_YZ, Y, Z, t, ξ; λ=λ, μ=μ, r=r)

Sample an alignment that adds an ancestor to the pairwise alignment `M_YZ`.

The steps are: compute `fulljointlogpdf(ξ, t, Y, Z)`, reduce it to one value per column
of `M_YZ` with `logpdfs`, build a `TKF92` model with two branches of length `t/2` and
`known_ancestor=false`, run `forward_anc`, and return the result of
`backward_sampling_anc`.

# Keywords
- `λ`, `μ`, `r`: forwarded as the 2nd–4th positional arguments of `TKF92`. They default
  to the notebook globals of the same name, which this function previously read
  implicitly, so existing calls behave as before.
"""
function sample_anc_alignment(M_YZ, Y, Z, t, ξ; λ=λ, μ=μ, r=r)
    emission_lps = fulljointlogpdf(ξ, t, Y, Z)
    align_emission_lps = logpdfs(M_YZ, emission_lps)
    anc_align_model = TKF92([t / 2, t / 2], λ, μ, r; known_ancestor=false)
    α = forward_anc(M_YZ, anc_align_model, align_emission_lps)
    return backward_sampling_anc(α, anc_align_model)
end
# Show the pairwise alignment, then sample the alignment with the ancestor and show it too.
M_YZ.data |> display
M_XYZ = sample_anc_alignment(M_YZ, Y, Z, t, ξ)
M_XYZ.data |> display

2×153 Matrix{Integer}:
 1  1  1  1  1  1  1  1  1  1  1  1  1  …  1  1  1  1  1  1  1  1  1  1  1  1
 1  1  1  1  1  1  1  1  1  1  1  1  1     1  1  1  1  1  1  1  1  1  1  1  1

3×153 Matrix{Integer}:
 1  1  1  1  1  1  1  1  1  1  1  1  1  …  1  1  1  1  1  1  1  1  1  1  1  1
 1  1  1  1  1  1  1  1  1  1  1  1  1     1  1  1  1  1  1  1  1  1  1  1  1
 1  1  1  1  1  1  1  1  1  1  1  1  1     1  1  1  1  1  1  1  1  1  1  1  1

In [165]:
using Turing 
using ReverseDiff
using TorusEvol
using LinearAlgebra
using StatsPlots
# Global Turing configuration for this cell.
# Select ReverseDiff as the AD backend.
Turing.setadbackend(:reversediff)
# NOTE(review): presumably turns on ReverseDiff tape caching — confirm against the
# installed Turing version.
Turing.setrdcache(true)
# Turn progress logging off (the cell output reports it as disabled globally).
Turing.setprogress!(false)

# Turing model linking a list of descendant matrices `descs` through a latent matrix `x`.
#   p     - process providing `statdist(p)` and `transdist(p, t, start)`
#   descs - collection of matrices; columns are indexed as `descs[d][:, i]`
#   t     - time argument forwarded to every `transdist` call
# NOTE(review): the body is currently identical to `bisection_sampler_sub`.
@model function bisection_sampler_vec_mv(p, descs, t)
    # The number of columns is taken from the first descendant only.
    # NOTE(review): assumes every other `descs[d]` has at least `n` columns — confirm.
    n = size(descs[1], 2)
    D = length(descs)
    # First descendant: `n` copies of the stationary distribution.
    descs[1] ~ filldist(statdist(p), n)
    
    # `Real` element type so the buffer can hold whatever number type the sampler supplies.
    x = Matrix{Real}(undef, length(p), n)
    for i ∈ 1:n
        # Latent column `i`: transition started at column `i` of the first descendant.
        x[:, i] ~ transdist(p, t, vec(descs[1][:, i]))
        for d ∈ 2:D
            # Remaining descendants: transition started at the latent column.
            descs[d][:, i] ~ transdist(p, t, vec(x[:, i]))
        end
    end
end

# Turing model linking a list of descendant matrices `descs` through a latent matrix `x`.
#   p     - process providing `statdist(p)` and `transdist(p, t, start)`
#   descs - collection of matrices; columns are indexed as `descs[d][:, i]`
#   t     - time argument forwarded to every `transdist` call
# NOTE(review): the body is currently identical to `bisection_sampler_vec_mv`; the `_sub`
# suffix suggests a variant for a different process type that has not been written yet.
@model function bisection_sampler_sub(p, descs, t)
    # The number of columns is taken from the first descendant only.
    n = size(descs[1], 2)
    D = length(descs)
    # First descendant: `n` copies of the stationary distribution.
    descs[1] ~ filldist(statdist(p), n)
    
    # NOTE(review): a `Real` buffer is allocated even if the process is discrete — confirm
    # this is intended.
    x = Matrix{Real}(undef, length(p), n)
    for i ∈ 1:n
        # Latent column `i`: transition started at column `i` of the first descendant.
        x[:, i] ~ transdist(p, t, vec(descs[1][:, i]))
        for d ∈ 2:D
            # Remaining descendants: transition started at the latent column.
            descs[d][:, i] ~ transdist(p, t, vec(x[:, i]))
        end
    end
end

# Single-site Turing model: `descs` is a collection of per-site values (the caller in this
# notebook passes one column per descendant) linked through one latent vector `x`.
#   p     - process providing `statdist(p)` and `transdist(p, t, start)`
#   descs - `descs[1]` is tied to the stationary distribution, the rest to transitions from `x`
#   t     - time argument forwarded to every `transdist` call
@model function bisection_sampler_wn(p, descs, t)
    D = length(descs)
    descs[1] ~ statdist(p)
    
    # Latent value with `length(p)` entries, started from the first descendant.
    x = Vector{Real}(undef, length(p))
    x ~ transdist(p, t, descs[1])
    for d ∈ 2:D
        # Every remaining descendant is a transition started at `x`.
        descs[d] ~ transdist(p, t, x)
    end
end

# Build one proposal per element `v` of `vv`: a two-component `MixtureModel` of
# `WrappedNormal(v, I)` and `WrappedNormal(v, 20*I)` with weights 0.8 and 0.2.
function torus_proposal(vv)
    proposal_at(v) = MixtureModel(
        [WrappedNormal(v, I), WrappedNormal(v, 20*I)], [0.8, 0.2]
    )
    return [proposal_at(v) for v ∈ vv]
end


"""
    sample_anc_coords_wn(p, descs, t; burn_in=300)

Sample one latent column per site with Metropolis–Hastings and return them as a
`length(p) × n` `Matrix{Real}`, where `n` is the number of columns of `descs[1]`.

For every column index `i`, column `i` of each matrix in `descs` is passed to
`bisection_sampler_wn(p, ·, t)`. A chain of `burn_in + 1` samples is drawn, and the last
sample of `x` becomes column `i` of the result.

# Keywords
- `burn_in`: number of samples discarded before the one that is kept.

NOTE(review): column `i` of every descendant is paired positionally, so each matrix in
`descs` needs at least `n` columns and is assumed to be site-aligned — confirm with callers.
"""
function sample_anc_coords_wn(p, descs, t; burn_in=300)
    mhs = MH(:x => v -> MixtureModel([WrappedNormal(v, I), WrappedNormal(v, 20*I)], [0.8, 0.2]))
    n = size(descs[1], 2)
    D = length(descs)
    K = length(p)
    X = Matrix{Real}(undef, K, n)
    for i ∈ 1:n
        descs_i = [descs[d][:, i] for d ∈ 1:D]
        model = bisection_sampler_wn(p, descs_i, t)
        chn = sample(model, mhs, burn_in + 1)
        # Read the chain once, then copy every component of `x`. This used to be
        # hard-coded to rows 1 and 2, which left other rows of `X` uninitialised.
        draws = get(chn, :x).x
        for k ∈ 1:K
            X[k, i] = draws[k][burn_in + 1]
        end
    end
    # Return the filled matrix; `x` (lower case) is not defined in this function.
    return X
end

"""
    sample_anc_coords(p, descs, t; burn_in=300)

Run Metropolis–Hastings on `bisection_sampler_vec_mv(p, descs, t)` for `burn_in + 1`
iterations, proposing `x` with `torus_proposal`, and return the resulting chain.

# Keywords
- `burn_in`: number of iterations run before the final one.

NOTE(review): this returns the whole chain, whereas `sample_anc_coords_wn` returns a
coordinate matrix and `ancestor_sampling` broadcasts this function's result into a
matrix view. Extracting the last sample of `x` from the chain is still to be done.
"""
function sample_anc_coords(p, descs, t; burn_in=300)
    mhs = MH(:x => vv -> torus_proposal(vv))
    # TODO: choose a particle sampler for integer-valued processes, along the lines of
    #   par = PG(100, :x)
    #   sampler = eltype(p) <: Integer ? par : mhs
    # `bisection_sampler_vec` does not exist in this notebook; the matrix-valued model
    # with this signature is `bisection_sampler_vec_mv`.
    model = bisection_sampler_vec_mv(p, descs, t)
    chn = sample(model, mhs, burn_in + 1)
    return chn
end


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m[Turing]: progress logging is disabled globally
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m[AdvancedVI]: global PROGRESS is set as false


sample_anc_coords (generic function with 1 method)

In [167]:
# Sample the latent second coordinate from the second coordinates of Y and Z, then show
# the sample followed by the two inputs.
X_angles = sample_anc_coords_wn(my_diff, [data(Y)[2], data(Z)[2]], t)
foreach(display, (X_angles, data(Y)[2], data(Z)[2]))

2×141 Matrix{Real}:
 -1.36342    -1.10835   -1.18526  …  -2.21012   -0.745595  -2.66263
  0.0718734  -0.945002   2.79838     -0.324282  -1.98636    0.0589801

2×141 Matrix{Real}:
 0.0     -1.56373  -1.32366  -0.987045  …  -1.86729   -1.18805   -2.78712
 2.7491   2.17697   2.91804  -0.760957      0.280312  -0.481882   0.0

2×145 Matrix{Real}:
 0.0      -1.24806  -1.47966  -1.29803   …  -1.67459   -1.28189  -2.73098
 1.93653   2.47863   3.11704  -0.591874      0.240513   2.62932   0.0

In [None]:
"""
    ancestor_sampling(M_YZ, Y, Z, t, ξ)

Sample an ancestor `X` for the aligned pair `Y`, `Z` and return it as `ObservedData`.

Step 1 samples the three-row alignment with `sample_anc_alignment`. Step 2 fills, for
every coordinate `c` and regime `e` of `ξ`, the columns of `X` according to the kind of
alignment column they sit in:

- `[1, 0, 0]`: drawn from `statdist(p)`;
- `[1, 1, 0]`: sampled with `sample_anc_coords` from the matching columns of `Y`;
- `[1, 0, 1]`: sampled with `sample_anc_coords` from the matching columns of `Z`;
- `[1, 1, 1]`: sampled with `sample_anc_coords` from the matching columns of `Y` and `Z`.

Every alignment column is currently assigned to regime 1.

NOTE(review): `sample_anc_coords`, as defined in this notebook, returns a chain rather
than a coordinate matrix. The three `.=` assignments below assume a matrix with one
column per selected site, so this needs reconciling before the function can run.
"""
function ancestor_sampling(M_YZ, Y, Z, t, ξ)
    C = num_coords(ξ)
    E = num_regimes(ξ)

    # step 1 - sample XYZ alignment
    M_XYZ = sample_anc_alignment(M_YZ, Y, Z, t, ξ)

    # step 2 - sample coordinates of X
    alignment = M_XYZ
    X_mask = mask(alignment, [[1], [0,1], [0,1]])
    alignmentX = alignment[X_mask]
    Y_mask = mask(alignment, [[0,1], [1], [0,1]])
    alignmentY = alignment[Y_mask]
    Z_mask = mask(alignment, [[0,1], [0,1], [1]])
    alignmentZ = alignment[Z_mask]

    # One regime label per alignment column. Integer labels, so that `regimes .== e`
    # compares like with like; the per-sequence views follow the same masks as above.
    M = length(alignment)
    regimes = ones(Int, M)
    regimesX = regimes[X_mask]
    regimesY = regimes[Y_mask]
    regimesZ = regimes[Z_mask]

    X_maskX = mask(alignmentX, [[1], [0], [0]])

    XY_maskX = mask(alignmentX, [[1], [1], [0]])
    XY_maskY = mask(alignmentY, [[1], [1], [0]])

    XZ_maskX = mask(alignmentX, [[1], [0], [1]])
    XZ_maskZ = mask(alignmentZ, [[1], [0], [1]])

    XYZ_maskX = mask(alignmentX, [[1], [1], [1]])
    XYZ_maskY = mask(alignmentY, [[1], [1], [1]])
    XYZ_maskZ = mask(alignmentZ, [[1], [1], [1]])

    dataY = data(Y)
    dataZ = data(Z)

    # Initialise internal coordinates of X
    N = sequence_lengths(M_XYZ)[1]
    dataX = [similar(dataY[c], size(dataY[c], 1), N) for c ∈ 1:C]

    for c ∈ 1:C, e ∈ 1:E
        p = processes(ξ)[c, e]

        # `.&` binds tighter than `.==`, so the regime test is computed first and
        # only then combined with the masks.
        inX = regimesX .== e
        inY = regimesY .== e
        inZ = regimesZ .== e

        # [1, 0, 0] - sample from stationary distribution
        dataX100 = @view dataX[c][:, X_maskX .& inX]
        n100 = size(dataX100, 2)
        dataX100 .= rand(statdist(p), n100)

        # [1, 1, 0] - observe Y, then sample X from Y
        selX110 = XY_maskX .& inX
        if any(selX110)  # skip the sampler when this regime has no such column
            dataY110 = @view dataY[c][:, XY_maskY .& inY]
            dataX110 = @view dataX[c][:, selX110]
            dataX110 .= sample_anc_coords(p, [dataY110], t)
        end

        # [1, 0, 1] - observe Z, then sample X from Z
        # (X is indexed with the X-side mask; the Z-side mask has a different length.)
        selX101 = XZ_maskX .& inX
        if any(selX101)
            dataZ101 = @view dataZ[c][:, XZ_maskZ .& inZ]
            dataX101 = @view dataX[c][:, selX101]
            dataX101 .= sample_anc_coords(p, [dataZ101], t)
        end

        # [1, 1, 1] - observe Y, sample X from Y, then observe Z from X
        selX111 = XYZ_maskX .& inX
        if any(selX111)
            dataY111 = @view dataY[c][:, XYZ_maskY .& inY]
            dataZ111 = @view dataZ[c][:, XYZ_maskZ .& inZ]
            dataX111 = @view dataX[c][:, selX111]
            dataX111 .= sample_anc_coords(p, [dataY111, dataZ111], t)
        end
    end

    return ObservedData(dataX)
end

"""
    trajectory_reconstruction(Y, Z, align_model, ξ, t; levels=1)

Recursively insert sampled ancestors between `Y` and `Z`.

One level samples a Y/Z alignment, samples the ancestor `X` with the five-argument
`ancestor_sampling`, then recurses on `(Y, X)` and `(X, Z)` with `t/2` and one level
fewer. Returns the ancestor sampled at this level, or `nothing` when `levels < 1`.
Results of the recursive calls are not collected yet.

# Keywords
- `levels`: number of bisection levels. The recursion stops when it reaches zero;
  previously it never terminated.

NOTE(review): the alignment step mirrors how `pair_hmm` and `M_YZ` are built at the top
of this notebook, and no forward pass filling `pair_hmm.α` is visible there either.
`align_model` is also reused unchanged for the sub-levels even though `t` is halved.
Both need confirming.
"""
function trajectory_reconstruction(Y, Z, align_model, ξ, t; levels=1)
    levels < 1 && return nothing

    # sample alignment
    pair_hmm = PairDataHMM(align_model, num_sites(Y), num_sites(Z))
    M_YZ = backward_sampling(pair_hmm.α, align_model)

    # for each level
    X = ancestor_sampling(M_YZ, Y, Z, t, ξ)
    trajectory_reconstruction(Y, X, align_model, ξ, t / 2; levels=levels - 1)
    trajectory_reconstruction(X, Z, align_model, ξ, t / 2; levels=levels - 1)
    return X
end

In [4]:
using Turing 
using ReverseDiff
using TorusEvol
# Global Turing configuration for this cell: ForwardDiff is selected, with the earlier
# ReverseDiff choice left commented out.
#Turing.setadbackend(:reversediff)
# NOTE(review): presumably only relevant to the ReverseDiff backend, which is not the one
# selected below — confirm whether this call is still needed.
Turing.setrdcache(true)
Turing.setadbackend(:forwarddiff)
# Turn progress logging back on.
Turing.setprogress!(true)

# Turing model for jointly sampling regimes and ancestor coordinates.
#   alignment - three-row alignment (X, Y, Z) accepted by `mask` and `sequence_lengths`
#   regimes   - overwritten below with a fresh buffer before being sampled
#   dataY/Z   - per-coordinate matrices of the two descendants, indexed as `dataY[c][:, cols]`
#   t         - time argument forwarded to every `transdist` call
#   ξ         - process providing `num_coords`, `num_regimes`, `weights` and `processes`
#
# Fixes relative to the earlier draft: `dataX` is now allocated (it was never defined),
# `data[c]` is `dataX[c]`, the [1, 0, 1] case scores `dataZ` with `XZ_maskZ` (it used
# `dataY` and the undefined `XY_maskZ`), and the regime test is parenthesised because
# `.&` binds tighter than `.==`.
#
# NOTE(review): the last recorded run of this cell ended in an InterruptException, so the
# model has not been validated end to end. Whether `~` / `.~` accept logically indexed
# left-hand sides for the installed Turing version also needs confirming.
@model function ancestor_sampling(alignment, regimes, dataY, dataZ, t, ξ)
    C = num_coords(ξ)
    E = num_regimes(ξ)

    X_mask = mask(alignment, [[1], [0,1], [0,1]])
    alignmentX = alignment[X_mask]
    Y_mask = mask(alignment, [[0,1], [1], [0,1]])
    alignmentY = alignment[Y_mask]
    Z_mask = mask(alignment, [[0,1], [0,1], [1]])
    alignmentZ = alignment[Z_mask]

    M = length(alignment)
    regimes = tzeros(Int, M)
    regimes ~ filldist(Categorical(weights(ξ)), M)
    regimesX = regimes[X_mask]
    regimesY = regimes[Y_mask]
    regimesZ = regimes[Z_mask]

    X_maskX = mask(alignmentX, [[1], [0], [0]])

    XY_maskX = mask(alignmentX, [[1], [1], [0]])
    XY_maskY = mask(alignmentY, [[1], [1], [0]])

    XZ_maskX = mask(alignmentX, [[1], [0], [1]])
    XZ_maskZ = mask(alignmentZ, [[1], [0], [1]])

    XYZ_maskX = mask(alignmentX, [[1], [1], [1]])
    XYZ_maskY = mask(alignmentY, [[1], [1], [1]])
    XYZ_maskZ = mask(alignmentZ, [[1], [1], [1]])

    # Storage for the latent coordinates of X, shaped like the non-model
    # `ancestor_sampling` does it: same row count as Y, one column per X site.
    N = sequence_lengths(alignment)[1]
    dataX = [similar(dataY[c], size(dataY[c], 1), N) for c ∈ 1:C]

    for c ∈ 1:C, e ∈ 1:E
        p = processes(ξ)[c, e]

        inX = regimesX .== e
        inY = regimesY .== e
        inZ = regimesZ .== e

        # [1, 0, 0] - sample from stationary distribution
        dataX[c][:, X_maskX .& inX] .~ statdist(p)

        # [1, 1, 0] - observe Y, then sample X from Y
        dataY[c][:, XY_maskY .& inY] .~ statdist(p)
        dataX[c][:, XY_maskX .& inX] ~ arraydist(transdist.(Ref(p), Ref(t), eachcol(dataY[c][:, XY_maskY .& inY])))

        # [1, 0, 1] - observe Z, then sample X from Z
        dataZ[c][:, XZ_maskZ .& inZ] .~ statdist(p)
        dataX[c][:, XZ_maskX .& inX] ~ arraydist(transdist.(Ref(p), Ref(t), eachcol(dataZ[c][:, XZ_maskZ .& inZ])))

        # [1, 1, 1] - observe Y, sample X from Y, then observe Z from X
        dataY[c][:, XYZ_maskY .& inY] .~ statdist(p)
        dataX[c][:, XYZ_maskX .& inX] ~ arraydist(transdist.(Ref(p), Ref(t), eachcol(dataY[c][:, XYZ_maskY .& inY])))
        dataZ[c][:, XYZ_maskZ .& inZ] ~ arraydist(transdist.(Ref(p), Ref(t), eachcol(dataX[c][:, XYZ_maskX .& inX])))
    end
end
    

InterruptException: InterruptException: