In [1]:
using TorusEvol
using Distributions
using TimerOutputs

# Load chain "A" of two PDB entries and extract their data for the likelihoods below.
chainY = from_pdb("1A3N", "A"; ramachandran=false)
chainZ = from_pdb("1MBN", "A"; ramachandran=false)
Y = data(chainY)
Z = data(chainZ)

# Single-regime mixture built from the WAG substitution process.
# The globals are deliberately NOT named `weights`/`processes`: a Main-level binding
# with either name shadows the accessor functions `weights(ξ)` / `processes(ξ)` used
# by the Turing models, turning those calls into a MethodError (calling an array).
w = WAG_SubstitutionProcess
regime_processes = [w;;]    # 1×1 matrix, indexed as [coordinate, regime]
regime_weights = [1.0]
ξ = MixtureProductProcess(regime_weights, regime_processes)

# Branch length and TKF92 indel parameters for the Y–Z pair alignment model.
t = 1.58
λ = 0.03; μ = 0.032; r = 0.3
align_model = TKF92([t], λ, μ, r)
pair_hmm = PairDataHMM(align_model, num_sites(Y), num_sites(Z));

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile exists: 1A3N
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile exists: 1MBN


In [2]:
# TKF92 model with two branches of length t/2 each and `known_ancestor=false`,
# wrapped in a pair HMM sized to the Y and Z sequences; the cell's value is the
# pair-HMM log-density of the per-site emission log-likelihoods from `fulllogpdf`.
model = TKF92(fill(t / 2, 2), λ, μ, r; known_ancestor=false)
pair_hmm = PairDataHMM(model, num_sites.((Y, Z))...)
logpdf(pair_hmm, fulllogpdf(ξ, t, Y, Z))

-825.0319721486123

In [3]:
# Sample an alignment by backward sampling; `pair_hmm.α` is presumably the forward
# table filled by the `logpdf` call in the previous cell — confirm.
alignment = backward_sampling(pair_hmm.α, model)
display(alignment.data)
N = first(sequence_lengths(alignment))

# Per-row allowed states: select the columns in which every row is in state 1.
allowed = [[1] for _ in 1:3]
mask(alignment, allowed)

3×155 Matrix{Integer}:
 1  1  1  1  1  1  1  1  1  1  1  1  0  …  1  1  1  1  1  0  0  0  0  0  1  1
 1  1  1  1  1  1  0  1  1  1  1  1  1     1  1  1  1  1  0  0  0  0  0  1  0
 1  1  1  1  1  1  1  1  1  1  1  1  0     1  1  1  1  1  1  1  1  1  1  1  1

155-element BitVector:
 1
 1
 1
 1
 1
 1
 0
 1
 1
 1
 1
 1
 0
 ⋮
 1
 1
 1
 1
 1
 0
 0
 0
 0
 0
 1
 0

In [None]:
using Turing 
using ReverseDiff
# Global Turing configuration for the sampling cells below.
Turing.setprogress!(true)            # progress bar during `sample`
Turing.setadbackend(:forwarddiff)    # gradients via ForwardDiff
# Alternative backend: Turing.setadbackend(:reversediff)
# NOTE(review): the ReverseDiff tape cache only matters for the ReverseDiff backend, and
# compiled tapes are only valid when control flow does not depend on sampled values
# (`ancestor_sampling` selects columns by sampled regimes) — confirm before switching.
Turing.setrdcache(true)

# Bisection sampler for a continuous process `p`: given endpoints `y` and `z`, infer the
# latent intermediate state `x` that lies time `t` after `y` and time `t` before `z`.
#   y     ~ statdist(p)          (stationary distribution)
#   x | y ~ transdist(p, t, y)   (transition over time t starting at y)
#   z | x ~ transdist(p, t, x)
# `y` and `z` are model arguments, so Turing treats them as observed; `x` is the only
# latent variable (sampled as `:x` by the caller).
@model function bisection_sampler_continuous(y, z, p, t)
    y ~ statdist(p)
    x ~ transdist(p, t, y)
    z ~ transdist(p, t, x)
end

# Toy run of the bisection sampler: a wrapped diffusion with endpoints drawn from its
# stationary distribution.
# `t_bisect` / `bisect_model` are distinct names so this cell does not overwrite the
# globals `t` (TKF92 branch length) and `model` (alignment model) from earlier cells,
# which would silently change those cells' results when they are re-run.
p = WrappedDiffusion(-1.0, -0.8, 0.1, 0.1, 1.0, 1.0, 0.2)
y = rand(statdist(p))
z = rand(statdist(p))
t_bisect = 0.2

bisect_model = bisection_sampler_continuous(y, z, p, t_bisect)
# HMC: step size 0.03, 5 leapfrog steps, updating only `x`; 10 draws.
ch = sample(bisect_model, HMC(0.03, 5, :x), 10)


In [None]:


# Turing model for the latent ancestor X of two observed sequences Y and Z.
#
# `alignment` has rows (X, Y, Z); `mask(alignment, allowed)` gives a BitVector over its
# columns that is true where row i's state is in `allowed[i]`. Every alignment column
# gets a regime drawn from `Categorical(weights(ξ))`, and for each coordinate `c` every
# X site is scored under the process `processes(ξ)[c, e]` of its regime `e`:
#   [1, 0, 0]  X ~ statdist(p)
#   [1, 1, 0]  Y observed ~ statdist(p); X | Y ~ transdist(p, t, Y)
#   [1, 0, 1]  Z observed ~ statdist(p); X | Z ~ transdist(p, t, Z)
#   [1, 1, 1]  Y observed ~ statdist(p); X | Y ~ transdist(p, t, Y); Z | X ~ transdist(p, t, X)
# Columns without an X site do not enter the likelihood.
#
# Arguments
# - `dataY`, `dataZ`: observed data per coordinate; `dataY[c][:, j]` is site j of Y.
# - `t`: time passed to every `transdist` call.
# - `ξ`: mixture process providing `num_coords`, `num_regimes`, weights and processes.
# - `T`: element type of the latent X storage (Turing's convention for AD-compatible
#   preallocation); defaults to `Float64`.
@model function ancestor_sampling(
    alignment, dataY, dataZ, t, ξ, ::Type{T}=Float64
) where {T}
    C = num_coords(ξ)
    E = num_regimes(ξ)

    # Alignment columns in which X, Y and Z respectively have a site.
    X_mask = mask(alignment, [[1], [0, 1], [0, 1]])
    Y_mask = mask(alignment, [[0, 1], [1], [0, 1]])
    Z_mask = mask(alignment, [[0, 1], [0, 1], [1]])
    alignmentX = alignment[X_mask]
    alignmentY = alignment[Y_mask]
    alignmentZ = alignment[Z_mask]

    # One regime per alignment column, then restricted to each sequence's own sites.
    # `TorusEvol.`-qualified so a notebook global named `weights`/`processes` cannot
    # shadow these accessors.
    M = length(alignment)
    regimes = tzeros(Int, M)
    regimes ~ filldist(Categorical(TorusEvol.weights(ξ)), M)
    regimesX = regimes[X_mask]
    regimesY = regimes[Y_mask]
    regimesZ = regimes[Z_mask]

    # Column patterns re-expressed in each sequence's own site indexing.
    X_maskX = mask(alignmentX, [[1], [0], [0]])

    XY_maskX = mask(alignmentX, [[1], [1], [0]])
    XY_maskY = mask(alignmentY, [[1], [1], [0]])

    XZ_maskX = mask(alignmentX, [[1], [0], [1]])
    XZ_maskZ = mask(alignmentZ, [[1], [0], [1]])

    XYZ_maskX = mask(alignmentX, [[1], [1], [1]])
    XYZ_maskY = mask(alignmentY, [[1], [1], [1]])
    XYZ_maskZ = mask(alignmentZ, [[1], [1], [1]])

    # Latent ancestor storage: one matrix per coordinate, one column per X site.
    # Every column is filled below (each X site has exactly one pattern and regime).
    # NOTE(review): assumes an X site has the same dimension as a Y site — confirm.
    num_X_sites = count(X_mask)
    dataX = [Matrix{T}(undef, size(dataY[c], 1), num_X_sites) for c in 1:C]

    for c in 1:C, e in 1:E
        p = TorusEvol.processes(ξ)[c, e]
        # Regime tests are computed on their own because `.&` binds tighter than `.==`:
        # `mask .& regimes .== e` would parse as `(mask .& regimes) .== e`.
        inX = regimesX .== e
        inY = regimesY .== e
        inZ = regimesZ .== e

        # [1, 0, 0] - sample from stationary distribution
        colsX = X_maskX .& inX
        if any(colsX)
            dataX[c][:, colsX] .~ statdist(p)
        end

        # [1, 1, 0] - observe Y, then sample X from Y
        colsX = XY_maskX .& inX
        if any(colsX)
            colsY = XY_maskY .& inY
            dataY[c][:, colsY] .~ statdist(p)
            dataX[c][:, colsX] ~ arraydist(
                transdist.(Ref(p), Ref(t), eachcol(dataY[c][:, colsY]))
            )
        end

        # [1, 0, 1] - observe Z, then sample X from Z
        colsX = XZ_maskX .& inX
        if any(colsX)
            colsZ = XZ_maskZ .& inZ
            dataZ[c][:, colsZ] .~ statdist(p)
            dataX[c][:, colsX] ~ arraydist(
                transdist.(Ref(p), Ref(t), eachcol(dataZ[c][:, colsZ]))
            )
        end

        # [1, 1, 1] - observe Y, sample X from Y, then observe Z from X
        colsX = XYZ_maskX .& inX
        if any(colsX)
            colsY = XYZ_maskY .& inY
            colsZ = XYZ_maskZ .& inZ
            dataY[c][:, colsY] .~ statdist(p)
            dataX[c][:, colsX] ~ arraydist(
                transdist.(Ref(p), Ref(t), eachcol(dataY[c][:, colsY]))
            )
            dataZ[c][:, colsZ] ~ arraydist(
                transdist.(Ref(p), Ref(t), eachcol(dataX[c][:, colsX]))
            )
        end
    end
end
    