## 1. Experiment setup

### 1.1. Sample evolutionary process parameters

In [53]:
using TorusEvol 
using Distributions
using Turing
using Random

Random.seed!(10203)
Turing.setprogress!(true)

# Number of descendants
D = 10
Y = Vector{ObservedChain}(undef, D) # descendants
M = Vector{Alignment}(undef, D) # alignments

# Evolutionary regimes: mixture weights drawn from a flat Dirichlet over E regimes
E = 2
weights = rand(Dirichlet(E, 1.0))

# Site level process (shared by every regime)
S = WAG_SubstitutionProcess()

# Per-regime dihedral process parameters, drawn from the prior and logged as
# `label` => value (message first, then the value as a key/value pair).
Ξ_1 = jwndiff_prior()(); Θ_1 = JumpingWrappedDiffusion(Ξ_1...); @info "Ξ_1" Ξ_1
Ξ_2 = jwndiff_prior()(); Θ_2 = JumpingWrappedDiffusion(Ξ_2...); @info "Ξ_2" Ξ_2
# Process matrix: rows are components (sequence, dihedrals), columns are regimes.
ξ = MixtureProductProcess(weights, [S S; Θ_1 Θ_2])

# Alignment parameters
Λ = tkf92_prior()(); @info "Λ" Λ

# Evolutionary distances: one per descendant, exponential with mean 0.1
ts = rand(Exponential(0.1), D)
ts

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m[Turing]: progress logging is enabled globally
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m[AdvancedVI]: global PROGRESS is set as true
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m(2.0793844375569694, 0.005773570704397102, 0.009389661226149627, 0.769318191229367, 0.0004517633068183017, 0.3867126231221906, 0.0008207706544378668, 0.8074528459487708)
[36m[1m└ [22m[39m  Ξ_1 = "Ξ_1"
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m(1.1344905145087365, -0.6151056108872228, 0.5045661989378393, 1.1942245564097882, 0.001300041441663701, 0.02682643626914298, 0.0005803159622610678, 1.3748720746700442)
[36m[1m└ [22m[39m  Ξ_2 = "Ξ_2"
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m(0.029162685377525838, 0.061302174093350016, 0.5520709365048283)
[36m[1m└ [22m[39m  Λ = "Λ"


10-element Vector{Float64}:
 0.31426686000171133
 0.15546045409280682
 0.06330746782915772
 0.11425781020314718
 0.1902670150680404
 0.1413388813018777
 0.3871452754000275
 0.04522098943588153
 0.04587982654249431
 0.01649363662998924

### 1.2. Sample ancestor and aligned descendants

In [61]:
# Ancestor X: drawn from the joint pair distribution at distance ts[1]; only X is
# kept, and every descendant is then sampled conditionally on it.
pair_chain_dist = ChainJointDistribution(ξ, TKF92([ts[1]], Λ...))
(X, _) = rand(pair_chain_dist)

for d ∈ 1:D
    τ = TKF92([ts[d]], Λ...)
    Y[d] = rand(ChainTransitionDistribution(ξ, τ, X))
    # Presumably the forward (α) variables of the X–Y[d] pair model, which the
    # conditioned alignment distribution samples from — confirm in TorusEvol.
    Γ = ChainJointDistribution(ξ, τ)
    α_XY = get_α(τ, (X, Y[d]))
    logpdfα!(α_XY, Γ, (X, Y[d]))
    M_d = Alignment(rand(ConditionedAlignmentDistribution(τ, α_XY)), τ)
    # Re-label the alignment with ids [1, d + 1] (presumably sequence ids: X is 1,
    # Y[d] is d + 1) so the pairwise alignments can be combined below.
    M[d] = Alignment(data(M_d), [1, d + 1])
end

chainX = from_primary_dihedrals(Int.(data(X)[1]), data(X)[2])
chainY = from_primary_dihedrals(Int.(data(Y[1])[1]), data(Y[1])[2])
lp = logpdf(pair_chain_dist, (X, Y[1]))
println("The log pdf of X and Y_1 is $lp")
# NOTE(review): this is not the cell's last expression, so a notebook will not
# display its return value automatically; wrap in `display(...)` if a visual is
# expected.
render(chainX, chainY; aligned=true)

# Fold the pairwise alignments into one, combining on id 1 (X) at every step.
# `foldl` gives the same left-nested result as combining M[1], M[2], then M[3], …,
# and also works when D == 1.
M_full = foldl((acc, m) -> combine(1, acc, m), M)

@info M_full

The log pdf of X and Y_1 is -83.51918048243054

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mSuperimposing based on a sequence alignment between 7 residues
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mSuperimposing based on 7 atoms
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mModel 1 with 2 chains (1,2), 14 residues, 54 atoms
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m
[36m[1m│ [22m[39m-#---#----##-#----#-#
[36m[1m│ [22m[39m-####--------#----#-#
[36m[1m│ [22m[39m-#---#----##-#----#-#
[36m[1m│ [22m[39m-#---##----##-----#-#
[36m[1m│ [22m[39m-#---#-#---#-#----#-#
[36m[1m│ [22m[39m#----#----##-#----#-#
[36m[1m│ [22m[39m-#------------#####-#
[36m[1m│ [22m[39m--------####-#----##-
[36m[1m│ [22m[39m-#---#----##-#----#-#
[36m[1m│ [22m[39m-#---#----##-#----#-#
[36m[1m│ [22m[39m-#---#----##-#----#-#
[36m[1m│ [22m[39m
[36m[1m└ [22m[39m


In [None]:
# Turing model for a single alignment column.
#   descs[1]       ~ stationary distribution of `p`
#   x              ~ transition of `p` over time `t` starting from descs[1]
#   descs[2:end]   ~ transition of `p` over time `t` starting from x
# `descs` are observed (model arguments); `x` is the latent ancestor.
# NOTE(review): treating x as a descendant of descs[1] presumably relies on `p`
# being time-reversible — confirm.
@model function bisection_sampler(p, descs, t)
    descs[1] ~ statdist(p)

    # `x` is not a model argument, so Turing treats it as a latent variable and
    # assigns it here; no pre-allocation is needed. Its name must stay `x`: the
    # samplers (`MH(:x => …)`, `PG(…, :x)`) and chain lookups refer to it.
    x ~ transdist(p, t, descs[1])
    for d ∈ 2:length(descs)
        descs[d] ~ transdist(p, t, x)
    end
end

"""
    sample_anc_coords_wn(p, descs, t; burn_in=300)

Sample ancestral coordinates for the process `p` with Metropolis–Hastings.

Each column `i` of the descendant matrices in `descs` is handled independently: the
`bisection_sampler` model for that column is run for `burn_in + 1` iterations and
the final draw of `x` is kept.

Returns a `length(p) × n` matrix, where `n = size(descs[1], 2)`.
"""
function sample_anc_coords_wn(p, descs, t; burn_in=300)
    # Proposal centred on the current state `v`: with probability 0.8 a
    # WrappedNormal with scale `I`, otherwise a wider one with scale `20 * I`
    # (presumably covariance matrices — confirm against TorusEvol's WrappedNormal).
    sampler = MH(:x => v -> MixtureModel(
        [WrappedNormal(v, I), WrappedNormal(v, 20 * I)], [0.8, 0.2]))
    n = size(descs[1], 2)
    X = Matrix{Real}(undef, length(p), n)
    for i ∈ 1:n
        descs_i = [desc[:, i] for desc ∈ descs]
        model = bisection_sampler(p, descs_i, t)
        chn = sample(model, sampler, burn_in + 1)
        # Per-component traces of `x`; keep the last draw of every component so
        # that all `length(p)` rows of `X` are filled (not just the first two).
        xs = get(chn, :x).x
        for j ∈ axes(X, 1)
            X[j, i] = xs[j][burn_in + 1]
        end
    end
    return X
end

"""
    sample_anc_coords_sub(p, descs, t; burn_in=300)

Sample ancestral states for the process `p` with particle Gibbs.

Each column `i` of the descendant matrices in `descs` is handled independently: the
`bisection_sampler` model for that column is run for `burn_in + 1` iterations and
the final draw of `x` is kept.

Returns a `length(p) × n` matrix, where `n = size(descs[1], 2)`.
"""
function sample_anc_coords_sub(p, descs, t; burn_in=300)
    # Particle Gibbs with 60 particles, updating the latent ancestor `x`.
    sampler = PG(60, :x)
    n = size(descs[1], 2)
    X = Matrix{Real}(undef, length(p), n)
    for i ∈ 1:n
        descs_i = [desc[:, i] for desc ∈ descs]
        model = bisection_sampler(p, descs_i, t)
        chn = sample(model, sampler, burn_in + 1)
        # Keep the last draw of every component so no row of `X` is left #undef.
        # NOTE(review): when `x` has a single element, MCMCChains' `get` may return
        # the trace array itself rather than a per-component collection, making
        # `xs[1]` a scalar — confirm for the substitution case.
        xs = get(chn, :x).x
        for j ∈ axes(X, 1)
            X[j, i] = xs[j][burn_in + 1]
        end
    end
    return X
end

"""
    sample_anc_coords(p, descs, t)

Sample ancestral values for process `p` given the descendant matrices `descs` at
distance `t`.

Processes whose `eltype` is an `Integer` (presumably discrete-state processes) are
handled by `sample_anc_coords_sub`; every other process goes to
`sample_anc_coords_wn`.
"""
function sample_anc_coords(p, descs, t)
    sampler_fn = eltype(p) <: Integer ? sample_anc_coords_sub : sample_anc_coords_wn
    return sampler_fn(p, descs, t)
end

"""
    ancestor_sampling(M_YZ, Y, Z, t, ξ)

Sample an ancestor chain `X` of the observed chains `Y` and `Z`, each at evolutionary
distance `t` from `X`.

Step 1 draws an `X`–`Y`–`Z` alignment with `sample_anc_alignment`. Step 2 fills in the
data of `X` separately for every component `c` and regime `e` of `ξ`. Each `X` column
is handled according to which descendants are aligned to it:

- `[1, 0, 0]` (`X` only): drawn from the stationary distribution;
- `[1, 1, 0]` (`X` and `Y`): drawn conditionally on `Y`;
- `[1, 0, 1]` (`X` and `Z`): drawn conditionally on `Z`;
- `[1, 1, 1]` (all three): drawn conditionally on both `Y` and `Z`.

Returns the sampled ancestor as an `ObservedChain`.
"""
function ancestor_sampling(M_YZ::Alignment, Y::ObservedChain, Z::ObservedChain,
                           t::Real, ξ::MixtureProductProcess)
    # step 1 - sample XYZ alignment
    M_XYZ = sample_anc_alignment(M_YZ, Y, Z, t, ξ)

    # step 2 - sample coordinates of X
    # Restrict the alignment to the columns where each sequence is present, so the
    # columns of each sub-alignment line up with that sequence's data columns.
    X_mask = mask(M_XYZ, [[1], [0, 1], [0, 1]])
    alignmentX = slice(M_XYZ, X_mask)
    Y_mask = mask(M_XYZ, [[0, 1], [1], [0, 1]])
    alignmentY = slice(M_XYZ, Y_mask)
    Z_mask = mask(M_XYZ, [[0, 1], [0, 1], [1]])
    alignmentZ = slice(M_XYZ, Z_mask)

    # Regime of every alignment column, then restricted to the columns of X, Y and
    # Z respectively (assumes masks are Bool vectors over alignment columns).
    # TODO: regimes are not sampled yet; every column is assigned to regime 1.
    regimes = ones(Int, length(X_mask))
    regimesX = regimes[X_mask]
    regimesY = regimes[Y_mask]
    regimesZ = regimes[Z_mask]

    # Column masks within each sub-alignment, one per presence pattern of (X, Y, Z).
    X_maskX = mask(alignmentX, [[1], [0], [0]])

    XY_maskX = mask(alignmentX, [[1], [1], [0]])
    XY_maskY = mask(alignmentY, [[1], [1], [0]])

    XZ_maskX = mask(alignmentX, [[1], [0], [1]])
    XZ_maskZ = mask(alignmentZ, [[1], [0], [1]])

    XYZ_maskX = mask(alignmentX, [[1], [1], [1]])
    XYZ_maskY = mask(alignmentY, [[1], [1], [1]])
    XYZ_maskZ = mask(alignmentZ, [[1], [1], [1]])

    dataY = data(Y)
    dataZ = data(Z)

    # `processes(ξ)` is indexed as [component, regime].
    procs = processes(ξ)
    C, E = size(procs)

    # Initialise internal coordinates of X
    N = sequence_lengths(M_XYZ)[1]
    dataX = [similar(dataY[c], size(dataY[c], 1), N) for c ∈ 1:C]

    for c ∈ 1:C, e ∈ 1:E
        p = procs[c, e]

        # Columns of each sequence that evolve under regime `e`. Computed separately
        # because `a .& b .== e` would parse as `(a .& b) .== e`.
        inX = regimesX .== e
        inY = regimesY .== e
        inZ = regimesZ .== e

        # [1, 0, 0] - sample from stationary distribution
        dataX100 = @view dataX[c][:, X_maskX .& inX]
        # `reshape` so univariate stationary distributions (whose `rand(d, n)`
        # returns a vector) fill the view just like multivariate ones.
        dataX100 .= reshape(rand(statdist(p), size(dataX100, 2)), size(dataX100))

        # [1, 1, 0] - observe Y, then sample X from Y
        dataY110 = @view dataY[c][:, XY_maskY .& inY]
        dataX110 = @view dataX[c][:, XY_maskX .& inX]
        dataX110 .= sample_anc_coords(p, [dataY110], t)

        # [1, 0, 1] - observe Z, then sample X from Z
        dataZ101 = @view dataZ[c][:, XZ_maskZ .& inZ]
        dataX101 = @view dataX[c][:, XZ_maskX .& inX]
        dataX101 .= sample_anc_coords(p, [dataZ101], t)

        # [1, 1, 1] - observe Y and Z, then sample X conditionally on both
        dataY111 = @view dataY[c][:, XYZ_maskY .& inY]
        dataZ111 = @view dataZ[c][:, XYZ_maskZ .& inZ]
        dataX111 = @view dataX[c][:, XYZ_maskX .& inX]
        dataX111 .= sample_anc_coords(p, [dataY111, dataZ111], t)
    end

    return ObservedChain(dataX)
end