In [14]:
using TorusEvol
using Distributions
using TimerOutputs
using Plots, StatsPlots
using LogExpFunctions

chainY = from_pdb("1A3N", "A")
chainZ = from_pdb("1MBN", "A")
Y = data(chainY)
Z = data(chainZ)

# Zero out NaN entries in the second data component of each observation,
# presumably undefined angles at chain ends (TODO confirm), so that NaNs do
# not reach the log-density computations below.
data(Y)[2][isnan.(data(Y)[2])] .= 0.0
data(Z)[2][isnan.(data(Z)[2])] .= 0.0

# Two diffusion processes (each wrapped by `jumping`), one per regime column
# of the process matrix below.
my_diff = jumping(WrappedDiffusion(-1.2, -0.8, 0.4, 0.4, 1.0, 1.0, 0.2), 1.0)
my_diff2 = jumping(WrappedDiffusion(-1.0, -1.0, 1.8, 1.8, 1.0, 1.0, 0.1), 1.0)

# Process matrix: rows are coordinates (first the WAG substitution process,
# then the diffusions), columns are regimes. `regime_weights` are the regime
# mixture weights.
# These are deliberately NOT named `processes` / `weights`. Those names are
# the accessors `processes(ξ)` / `weights(ξ)` called by `ancestor_sampling`,
# and a global of the same name in Main would shadow them.
w = WAG_SubstitutionProcess()
process_matrix = [w w; my_diff my_diff2]
regime_weights = [0.85, 0.15]
ξ = MixtureProductProcess(regime_weights, process_matrix)

# TKF92 alignment model with a known ancestor over a single branch of length t.
# λ, μ, r are presumably the insertion rate, deletion rate and fragment
# extension probability (TODO confirm against TKF92's docs).
t = 1.58
λ = 0.03; μ = 0.032; r = 0.3
align_model = TKF92([t], λ, μ, r; known_ancestor=true)
pair_hmm = PairDataHMM(align_model, num_sites(Y), num_sites(Z));

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile exists: 1A3N
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile exists: 1MBN


In [15]:
"""
    logpdfs(alignment, emission_lps)

Return the emission log-probability of every column of a pairwise `alignment`,
read from the table `emission_lps`.

`emission_lps` has one extra last row and one extra last column beyond the two
sequence lengths `N` and `M`. Those extra entries are used for columns where
only one sequence advances (presumably the gap/insertion terms).

Each column `v` of `alignment` must be one of:
- `[1, 1]`: both sequences advance; yields `emission_lps[i, j]`;
- `[1, 0]`: only the first advances; yields `emission_lps[i, M + 1]`;
- `[0, 1]`: only the second advances; yields `emission_lps[N + 1, j]`.

# Returns
A `Vector` whose element type is `eltype(emission_lps)`, with one entry per
alignment column. An empty alignment yields an empty vector.

# Throws
- `ArgumentError` if a column matches none of the three patterns above.
"""
function logpdfs(alignment, emission_lps)
    N, M = size(emission_lps) .- 1
    # Concrete element type (Float64, or an AD number type) instead of the
    # abstract `Real`, so downstream code stays type-stable.
    res = eltype(emission_lps)[]
    i = 0  # sites consumed in the first sequence
    j = 0  # sites consumed in the second sequence
    for v ∈ alignment
        if v == [1, 1]
            i += 1
            j += 1
            push!(res, emission_lps[i, j])
        elseif v == [1, 0]
            i += 1
            push!(res, emission_lps[i, M + 1])
        elseif v == [0, 1]
            j += 1
            push!(res, emission_lps[N + 1, j])
        else
            # Previously any other column fell through to the [0, 1] branch.
            throw(ArgumentError(
                "unexpected alignment column $v; expected [1, 1], [1, 0] or [0, 1]"))
        end
    end
    return res
end


logpdfs (generic function with 1 method)

In [12]:
# Emission log-probability table for the pair (Y, Z) under the mixture process ξ
# at time t.
emission_lps = fulllogpdf(ξ, t, Y, Z)
# The return value is discarded. This is presumably called for its side effect
# of filling `pair_hmm.α`, which the next line reads (TODO confirm).
logpdf(pair_hmm, emission_lps)
# Sample a pairwise alignment from the filled `α` matrix.
alignment = backward_sampling(pair_hmm.α, align_model)
# Per-column emission log-probabilities along the sampled alignment.
align_emission_lps = logpdfs(alignment, emission_lps)
# Unknown-ancestor model: the total time t is split evenly into [t/2, t/2]
# (presumably one half per branch from the ancestor).
anc_align_model = TKF92([t/2, t/2], λ, μ, r; known_ancestor=false)
α = forward_anc(alignment, anc_align_model, align_emission_lps)

-1505.2209433608903

In [3]:
# Time the emission-table computation and the optimized forward pass for the
# unknown-ancestor TKF92 model (time split as [t/2, t/2]).
model = TKF92([t/2, t/2], λ, μ, r; known_ancestor=false)
# NOTE(review): this rebinds the global `pair_hmm`, replacing the one built from
# `align_model` in the setup cell. Code that calls
# `backward_sampling(pair_hmm.α, align_model)` expects the setup-cell version.
pair_hmm = PairDataHMM(model, num_sites(Y), num_sites(Z))
# NOTE(review): `to` is not defined anywhere in this notebook. It is presumably
# a TimerOutput exported by TorusEvol, which would explain the nested
# "logsumexp α" / "jointlogpdf" sections in the report (TODO confirm).
reset_timer!(to)

@timeit to "emission lps" emission_lps = fulllogpdf(ξ, t, Y, Z)

@timeit to "forward" _logpdf(pair_hmm, emission_lps; optimized=true)
print(to)

[0m[1m ──────────────────────────────────────────────────────────────────────────[22m
[0m[1m                         [22m         Time                    Allocations      
                         ───────────────────────   ────────────────────────
    Tot / % measured:         1.53s /  99.9%            372MiB / 100.0%    

 Section         ncalls     time    %tot     avg     alloc    %tot      avg
 ──────────────────────────────────────────────────────────────────────────
 forward              1    1.52s   99.1%   1.52s    366MiB   98.4%   366MiB
   logsumexp α     261k   55.1ms    3.6%   211ns   3.98MiB    1.1%    16.0B
 emission lps         1   13.5ms    0.9%  13.5ms   5.82MiB    1.6%  5.82MiB
   jointlogpdf        1   4.28ms    0.3%  4.28ms   4.28MiB    1.2%  4.28MiB
   logsumexp          3   4.24ms    0.3%  1.41ms    351KiB    0.1%   117KiB
   statlogpdf         2   57.6μs    0.0%  28.8μs   22.0KiB    0.0%  11.0KiB
[0m[1m ────────────────────────────────────────────────────

In [4]:
using Turing 
using ReverseDiff
using TorusEvol
# Automatic-differentiation configuration for Turing samplers.
# The ReverseDiff backend is kept below as a disabled alternative.
#Turing.setadbackend(:reversediff)
# NOTE(review): the ReverseDiff tape cache only matters with the ReverseDiff
# backend. ForwardDiff is the backend actually selected on the next line, so
# this call is likely a leftover.
Turing.setrdcache(true)
Turing.setadbackend(:forwarddiff)
# Show sampling progress bars.
Turing.setprogress!(true)

# Turing model for the ancestor sequence X of two observed sequences Y and Z,
# given a fixed three-way `alignment`. Column patterns are selected with
# `mask(alignment, [X-states, Y-states, Z-states])`; the row order X, Y, Z is
# inferred from the comments below (TODO confirm against `mask`).
#
# Every alignment column draws a latent regime from `Categorical(weights(ξ))`.
# For each coordinate `c` and regime `e`, the process `p = processes(ξ)[c, e]`
# is used as follows:
#   [1, 0, 0]  X columns drawn from `statdist(p)`;
#   [1, 1, 0]  Y columns observed from `statdist(p)`, and X columns drawn from
#              `transdist(p, t, y)` for the matching Y columns `y`;
#   [1, 0, 1]  the same with Z in place of Y;
#   [1, 1, 1]  Y observed, X drawn given Y, then Z observed given X.
#
# `T` is the element type of the latent ancestor data `dataX`. Turing can
# substitute an AD-compatible type; it defaults to Float64, so the original
# five-argument call form still works.
@model function ancestor_sampling(
    alignment, dataY, dataZ, t, ξ, ::Type{T}=Float64
) where {T}
    C = num_coords(ξ)
    E = num_regimes(ξ)

    # Sub-alignments restricted to the columns where X, Y or Z has a site.
    X_mask = mask(alignment, [[1], [0, 1], [0, 1]])
    alignmentX = alignment[X_mask]
    Y_mask = mask(alignment, [[0, 1], [1], [0, 1]])
    alignmentY = alignment[Y_mask]
    Z_mask = mask(alignment, [[0, 1], [0, 1], [1]])
    alignmentZ = alignment[Z_mask]

    # One latent regime per alignment column, projected onto each sequence.
    M = length(alignment)
    regimes = tzeros(Int, M)
    regimes ~ filldist(Categorical(weights(ξ)), M)
    regimesX = regimes[X_mask]
    regimesY = regimes[Y_mask]
    regimesZ = regimes[Z_mask]

    # `dataX` is not a model argument, so it must be allocated here before `~`
    # assigns its columns. There is one matrix per coordinate and one column per
    # X site. The row count is copied from Y's data, which assumes X and Y share
    # each coordinate's dimension (TODO confirm).
    dataX = [Matrix{T}(undef, size(dataY[c], 1), length(alignmentX)) for c ∈ 1:C]

    # Per-sequence masks for each column pattern.
    X_maskX = mask(alignmentX, [[1], [0], [0]])

    XY_maskX = mask(alignmentX, [[1], [1], [0]])
    XY_maskY = mask(alignmentY, [[1], [1], [0]])

    XZ_maskX = mask(alignmentX, [[1], [0], [1]])
    XZ_maskZ = mask(alignmentZ, [[1], [0], [1]])

    XYZ_maskX = mask(alignmentX, [[1], [1], [1]])
    XYZ_maskY = mask(alignmentY, [[1], [1], [1]])
    XYZ_maskZ = mask(alignmentZ, [[1], [1], [1]])

    for c ∈ 1:C, e ∈ 1:E
        p = processes(ξ)[c, e]

        # Regime indicators are computed separately on purpose. `.&` binds
        # tighter than `.==`, so `mask .& regimes .== e` would evaluate
        # `(mask .& regimes) .== e`.
        inX = regimesX .== e
        inY = regimesY .== e
        inZ = regimesZ .== e

        # [1, 0, 0] - sample X from the stationary distribution
        dataX[c][:, X_maskX .& inX] .~ statdist(p)

        # [1, 1, 0] - observe Y, then sample X from Y
        colsX = XY_maskX .& inX
        colsY = XY_maskY .& inY
        dataY[c][:, colsY] .~ statdist(p)
        dataX[c][:, colsX] ~ arraydist(transdist.(Ref(p), Ref(t), eachcol(dataY[c][:, colsY])))

        # [1, 0, 1] - observe Z, then sample X from Z
        colsX = XZ_maskX .& inX
        colsZ = XZ_maskZ .& inZ
        dataZ[c][:, colsZ] .~ statdist(p)
        dataX[c][:, colsX] ~ arraydist(transdist.(Ref(p), Ref(t), eachcol(dataZ[c][:, colsZ])))

        # [1, 1, 1] - observe Y, sample X from Y, then observe Z from X
        colsX = XYZ_maskX .& inX
        colsY = XYZ_maskY .& inY
        colsZ = XYZ_maskZ .& inZ
        dataY[c][:, colsY] .~ statdist(p)
        dataX[c][:, colsX] ~ arraydist(transdist.(Ref(p), Ref(t), eachcol(dataY[c][:, colsY])))
        dataZ[c][:, colsZ] ~ arraydist(transdist.(Ref(p), Ref(t), eachcol(dataX[c][:, colsX])))
    end
end
    

InterruptException: InterruptException: