## 1. Prepare data

### 1.1 Load an example from real protein data

In [16]:
using TorusEvol
using Distributions

# Ground-truth evolutionary process used throughout the examples below.
# Branch lengths: t_Y and t_Z lead to the observed chains Y and Z, t_W to W.
t_Y, t_Z, t_W = 0.3, 0.2, 0.2
# TKF92 alignment parameters: birth rate λ (kept below the death rate μ) and r.
λ, μ, r = 0.02, 0.0203, 0.4
τ = TKF92([t_Y + t_Z], λ, μ, r)

# Site-level processes: empirical WAG amino-acid substitution ...
S = WAG_SubstitutionProcess()
# ... and a jumping wrapped diffusion on the (𝜙, 𝜓) dihedral angles.
μ_𝜙, μ_𝜓 = -1.0, -0.8
σ_𝜙, σ_𝜓 = 0.8, 0.8
α_𝜙, α_𝜓, α_cov = 0.5, 1.0, 0.1
γ = 0.2
Θ = JumpingWrappedDiffusion(μ_𝜙, μ_𝜓, σ_𝜙, σ_𝜓, α_𝜙, α_𝜓, α_cov, γ)

# Joint sequence-structure process per site, lifted to whole chains via the alignment model.
ξ = ProductProcess(S, Θ)
Γ = ChainJointDistribution(ξ, τ)

# Observed chains from the PDB, keeping data columns 1 and 2 of each.
chainY = from_pdb("1MBN", "A")
Y = slice(data(chainY), :, [1, 2])
chainZ = from_pdb("1A3N", "A")
Z = slice(data(chainZ), :, [1, 2])
chainW = from_pdb("1A3N", "B")
W = slice(data(chainW), :, [1, 2])

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile exists: 1MBN
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile exists: 1A3N
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile exists: 1A3N


ObservedChain(
num sites: 145
num coords: 2
)

In [17]:
# Three-sequence TKF92 model: an unobserved ancestor X with observed descendants Y and Z.
τ_XYZ = TKF92([t_Y, t_Z], λ, μ, r; known_ancestor=false)
Γ = ChainJointDistribution(ξ, τ_XYZ)

# Allocate the dynamic-programming buffer for (Y, Z); `logpdfα!` fills it in place and
# returns the log-density printed below.
α_YZ = get_α(τ_XYZ, (Y, Z))
lp = logpdfα!(α_YZ, Γ, (Y, Z)); print(lp)

# Sample an alignment of (X, Y, Z) conditioned on the filled buffer. The distribution is
# built once and reused, rather than constructed a second time for `rand`.
c = ConditionedAlignmentDistribution(τ_XYZ, α_YZ)
M_XYZ_data = rand(c); M_XYZ = Alignment([1, 2, 3], M_XYZ_data);

-1514.2742268247348

In [19]:
using LogExpFunctions
# Reconstruct the hidden ancestor X of Y and Z from the sampled alignment M_XYZ.
X = hiddenchain_from_alignment(Y, Z, t_Y, t_Z, M_XYZ, ξ)

# Pairwise model X → W with a known ancestor; construct it once and reuse it for the
# buffer allocation, the transition distribution and the conditioned sampler.
τ_XW = TKF92([t_W], λ, μ, r; known_ancestor=true)
α_XW = get_α(τ_XW, [X, W])
C_XW = ChainTransitionDistribution(ξ, τ_XW, X)
# Called only to fill α_XW in place; the returned log-density is not needed here.
logpdfα!(α_XW, C_XW, W)
cxw = ConditionedAlignmentDistribution(τ_XW, α_XW)
M_XW = Alignment([1, 4], rand(cxw))

# Merge the two alignments through sequence 1 (X), which both of them contain, then
# drop X to show the alignment of the observed chains Y, Z, W.
M_XYZW = combine(1, M_XYZ, M_XW)
M_YZW = subalignment(M_XYZW, [2, 3, 4])
show_filled_alignment(
    M_YZW,
    id_to_aa.(data(Y)[1]),
    id_to_aa.(data(Z)[1]),
    id_to_aa.(data(W)[1]),
)


-VLSEGEWQLVLHVWAKVE-ADVAGHGQDILIRLFKSHPETLEKFDRF-KHLK--TEAEMKA----SEDLKKHGV
V-LSPADKTNVKAAWGKV-GAHAGEYGAEALERMFLSFPTTKTYFP----HF---D------LSHGSAQVKGHGK
H-LTPEEKSAVTALWGK---VNVDEVGGEALGRLLVVYPWTQRFFESFG--D-LSTPDAVM---G-NPKVKAHGK

TVLTAL-GAILKKKGHH---EAELKPLAQSHATKHKIPIKYLEFISEAIIHVLH-SRHPGDFGADAQGAMNKALE
KVADALTNAVA----HVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLA-AHLPAEFTPAVHASLDKFLA
KVLGAF-SD--GLA-HLDNLKGTFATLSELHCDKLHVDPENFRLLGNVLVCVLAHHFGK-EFTPPVQAAYQKVVA

LFRKDIAAK-YKELGYQG
SVSTVLTSK-Y-----R-
GVANALAHKY------H-



## 2. Bayesian Model for Parameter Inference

### 2.1 Set up priors for evolutionary processes

In [15]:
using Turing, DynamicPPL
using LinearAlgebra
using LogExpFunctions 
using Plots, StatsPlots
using Random

import Base: length, eltype
import Distributions: _rand!, logpdf

# Enable Turing's progress logging globally (see the Info messages printed below).
Turing.setprogress!(true)

# Beta(α, β) distribution rescaled from [0, 1] to [-1, 1] through x = 2b - 1.
# Used as the prior on the correlation-like parameter `α_corr` in `jwndiff_prior`.
struct ScaledBeta{T<:Real} <: ContinuousUnivariateDistribution
    be::Beta{T}
    function ScaledBeta(α::Real, β::Real)
        be = Beta(α, β)
        # Parametrise on the Beta's element type so the field is concretely typed.
        return new{typeof(be.α)}(be)
    end
end

# Support is [-1, 1]; this also enables the default `insupport`.
Base.minimum(::ScaledBeta) = -1.0
Base.maximum(::ScaledBeta) = 1.0

# Draw from the underlying Beta with the caller's `rng` (so seeded runs are reproducible)
# and map it onto [-1, 1].
Distributions.rand(rng::AbstractRNG, d::ScaledBeta) = 2 * rand(rng, d.be) - 1

# Change of variables b = (x + 1) / 2 has |db/dx| = 1/2, hence the -log(2) Jacobian term.
Distributions.logpdf(d::ScaledBeta, x::Real) = logpdf(d.be, (x + 1) / 2) - log(2)


# Ordered pair (λ, μ) with λ ≤ μ, obtained by sorting two i.i.d. exponential draws.
# Its density is 2 f(λ) f(μ) on the region λ ≤ μ and zero elsewhere.
struct CompetingExponential{T<:Real} <: ContinuousMultivariateDistribution
    ex::Exponential{T}
    function CompetingExponential(rate::Real)
        rate > 0 || throw(DomainError(rate, "rate must be positive"))
        # Distributions' `Exponential(θ)` takes the scale θ = 1 / rate, not the rate.
        ex = Exponential(inv(rate))
        return new{typeof(ex.θ)}(ex)
    end
end
Base.eltype(d::CompetingExponential) = Float64
Base.length(d::CompetingExponential) = 2

function Distributions._rand!(
    rng::AbstractRNG, d::CompetingExponential, x::AbstractVector{<:Real}
)
    # Sort two i.i.d. draws so that the first entry (birth rate) is the smaller one.
    lo, hi = minmax(rand(rng, d.ex), rand(rng, d.ex))
    x[begin] = lo
    x[begin + 1] = hi
    return x
end

function Distributions._logpdf(d::CompetingExponential, x::AbstractArray)
    λ, μ = first(x), last(x)
    # Zero density outside the ordered region λ ≤ μ.
    λ > μ && return -Inf
    # The factor 2 accounts for both orderings being folded onto λ ≤ μ.
    return log(2) + logpdf(d.ex, λ) + logpdf(d.ex, μ)
end

@model function tkf92_prior()
    # Ordered TKF92 rates (λ, μ) with λ ≤ μ, plus the parameter r on (0, 1).
    λμ ~ CompetingExponential(1.0)
    λ, μ = λμ[1], λμ[2]
    r ~ Uniform(0.0, 1.0)

    # The birth rate must stay below the death rate and every parameter must lie in its
    # range (random-walk MH proposals can step outside). An invalid draw is signalled
    # to the caller by returning NaN rates.
    out_of_range = (λ > μ) | (λ ≤ 0) | (μ ≤ 0) | (r ≤ 0) | (r ≥ 1)
    return out_of_range ? (NaN, NaN, r) : (λ, μ, r)
end;

@model function jwndiff_prior()
    # Priors for the jumping wrapped diffusion on (𝜙, 𝜓).
    μ ~ filldist(Uniform(-π, π), 2)        # stationary means on the torus
    σ² ~ filldist(Gamma(π * 0.1), 2)
    α ~ filldist(Gamma(π * 0.1), 2)
    γ ~ Exponential(1.0)   # jumping rate
    α_corr ~ ScaledBeta(3, 3)              # in [-1, 1]; scales the off-diagonal α_cov

    # Invalid draws (possible under random-walk proposals) are signalled with NaNs.
    # Build the return tuple directly instead of overwriting the sampled vectors in
    # place, so the model's random variables are never mutated.
    if any(≤(0), σ²) || any(≤(0), α) || γ ≤ 0
        return μ[1], μ[2], NaN, NaN, NaN, NaN, NaN, NaN
    end

    # Off-diagonal term; |α_corr| ≤ 1 keeps the 2×2 matrix [α₁ α_cov; α_cov α₂] valid.
    # The check guards against α_corr leaving [-1, 1] and against rounding.
    α_cov = α_corr * sqrt(α[1] * α[2])
    if α_cov^2 > α[1] * α[2]
        α_cov = NaN
    end

    return μ[1], μ[2], sqrt(σ²[1]), sqrt(σ²[2]), α[1], α[2], α_cov, γ
end;

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m[Turing]: progress logging is enabled globally
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m[AdvancedVI]: global PROGRESS is set as true


### 2.2 Set up the sampler

In [None]:
# Proposal for the torus-valued dihedral means: a two-component WrappedNormal mixture
# centred on the current value, weighted 0.8 on the component with I and 0.2 on 20I.
torus_proposal(v) =
    MixtureModel([WrappedNormal(v, I), WrappedNormal(v, 20 * I)], [0.8, 0.2])

# Multivariate Gaussian random walk centred at `v`; `cov` is passed to `MvNormal`
# as the covariance.
mv_rw_proposal(v::AbstractVector, cov) = MvNormal(v, cov)

# Univariate Gaussian random walk centred at `x`. `Normal`'s second argument is a
# standard deviation, not a variance.
rw_proposal(x, sd) = Normal(x, sd)


# Metropolis-within-Gibbs: one MH update per parameter block, visited in this order.
# Names such as "Θ.μ" are the prefixed variable names created by `@submodel prefix=...`.
sampler = Gibbs(
    MH(:t => cur -> rw_proposal(cur, 0.2)),
    MH(Symbol("Θ.μ") => torus_proposal),
    MH(Symbol("Θ.σ²") => cur -> mv_rw_proposal(cur, 0.4 * I)),
    MH(Symbol("Θ.α") => cur -> mv_rw_proposal(cur, 0.4 * I)),
    MH(Symbol("Θ.α_corr") => cur -> rw_proposal(cur, 0.5)),
    MH(Symbol("Θ.γ") => cur -> rw_proposal(cur, 0.5)),
    MH(Symbol("τ.λμ") => cur -> mv_rw_proposal(cur, [0.4 0.1; 0.1 0.6])),
    MH(Symbol("τ.r") => cur -> rw_proposal(cur, 0.5)),
);

In [None]:
using Memoization 

# Per-task caches of the dynamic-programming buffers used by
# `pair_param_inference_simple`, keyed on the identity of `pairs`.
#
# These buffers are mutated on every model evaluation (`fulljointlogpdf!`,
# `logpdfαB!`). A single global cache would be shared, and written to, by all tasks
# that `MCMCThreads()` runs concurrently. `task_local_storage()` gives each task its own
# buffers and its own cache dictionary (an IdDict, so lookups are by object identity).
get_αs(pairs) = get!(task_local_storage(), (:get_αs, pairs)) do
    # The TKF92 parameters here are placeholders. NOTE(review): presumably they only
    # fix the buffer shapes; confirm `get_α` does not bake them into the result.
    get_α.(Ref(TKF92([1.0], 0.2, 0.3, 0.4)), pairs)
end
get_Bs(pairs) = get!(() -> get_B.(pairs), task_local_storage(), (:get_Bs, pairs))

### 2.3 Prepare the probabilistic model

In [None]:
using TimerOutputs

@model function pair_param_inference_simple(pairs)
    # Posterior over a shared divergence time t, the TKF92 alignment parameters and the
    # dihedral-diffusion parameters, given a collection `pairs` of observed (X, Y) chain
    # pairs. Each pair contributes its joint log-density, marginalised over alignments.
    # ____________________________________________________________________________________
    # Step 1 - Sample prior parameters

    # Time parameter
    t ~ Exponential(1.0)
    # Alignment parameters (λ, μ, r); NaN entries mark an invalid draw
    @submodel prefix="τ" Λ = tkf92_prior()
    # Dihedral parameters; NaN entries mark an invalid draw
    @submodel prefix="Θ" Ξ = jwndiff_prior()
    # Reject invalid parameter draws outright.
    if t ≤ 0 || any(isnan, Ξ) || any(isnan, Λ)
        Turing.@addlogprob! -Inf
        return
    end

    # ____________________________________________________________________________________
    # Step 2 - Construct processes

    # Substitution process - no parameters for simplicity, use the fully empirical model
    S = WAG_SubstitutionProcess()
    # Dihedral process
    Θ = JumpingWrappedDiffusion(Ξ...)
    # Joint sequence-structure site-level process with a single regime of weight 1.
    # NOTE(review): `hcat` of a 2-element vector yields a 2×1 matrix; confirm this is the
    # layout `MixtureProductProcess` expects.
    ξ = MixtureProductProcess([1.0], hcat([S, Θ]))

    # Alignment model
    τ = TKF92([t], Λ...)

    # Chain-level model
    Γ = ChainJointDistribution(ξ, τ)

    # ____________________________________________________________________________________
    # Step 3 - Observe each pair (X, Y) through its joint probability, marginalising over
    # alignments. The buffers come from per-pair caches and are refilled in place here.
    α = get_αs(pairs)
    B = get_Bs(pairs)
    for ((X, Y), αᵢ, Bᵢ) in zip(pairs, α, B)
        # Equivalent to (X, Y) ~ ChainJointDistribution(ξ, τ)
        fulljointlogpdf!(Bᵢ, ξ, t, X, Y)
        Turing.@addlogprob! logpdfαB!(αᵢ, Bᵢ, Γ, (X, Y))
    end

    return Γ
end;

### 2.4 Sample from the model and check results

In [None]:
# Draws per chain and number of chains passed to `sample` below.
num_samples = 200
num_chains = 3
# NOTE(review): `simulated_data` is not defined in any cell shown here. It presumably
# holds the collection of (X, Y) chain pairs the model iterates over; confirm it is
# created before this cell runs.
model = pair_param_inference_simple(simulated_data)

In [None]:
# Run `num_chains` chains of `num_samples` draws each on multiple threads, then plot them.
chain = sample(model, sampler, MCMCThreads(), num_samples, num_chains)
p = plot(chain; fontfamily = "JuliaMono")

In [None]:
# Display the sampled chains (the cell's last expression is shown as output).
chain