## 1. Prepare data

### 1.1 Load a real-data example

In [10]:
using TorusEvol
using Distributions

# Underlying evolutionary process
t_Y = 0.5; t_Z = 0.3; t_W = 0.9    # evolutionary times for chains Y, Z and W
λ = 0.03; μ = 0.0308; r = 0.4      # TKF92 parameters: birth rate λ < death rate μ, r ∈ (0, 1)

# Alignment process for the Y–Z pair, using the single time t_Y + t_Z
τ = TKF92([t_Y + t_Z], λ, μ, r)

# Site-level processes: WAG substitution for residues, jumping wrapped diffusion for dihedrals
S = WAG_SubstitutionProcess()
μ_𝜙 = -1.0; μ_𝜓 = -0.8; σ_𝜙 = 0.8; σ_𝜓 = 0.8; α_𝜙 = 0.5; α_𝜓 = 1.0; α_cov = 0.1; γ = 0.2
Θ = JumpingWrappedDiffusion(μ_𝜙, μ_𝜓, σ_𝜙, σ_𝜓, α_𝜙, α_𝜓, α_cov, γ)

# Joint sequence–structure site process and the chain-level distribution built on τ
ξ = ProductProcess(S, Θ)
Γ = ChainJointDistribution(ξ, τ)

# Example chains, given as (PDB id, chain id)
chainY = from_pdb("1A3N", "A"); Y = data(chainY)
chainZ = from_pdb("1MBN", "A"); Z = data(chainZ)
chainW = from_pdb("1A3N", "B"); W = data(chainW)

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile exists: 1A3N
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile exists: 1MBN
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile exists: 1A3N


LoadError: UndefVarError: `LED` not defined

In [6]:
# Alignment model for Y and Z with known_ancestor=false
# (presumably X is their unobserved common ancestor — confirm in TorusEvol docs)
τ_XYZ = TKF92([t_Y, t_Z], λ, μ, r; known_ancestor=false)
Γ = ChainJointDistribution(ξ, τ_XYZ)

# Log-probability of the observed pair (Y, Z); α_YZ is the work table passed to logpdfα!
α_YZ = get_α(τ_XYZ, (Y, Z))
lp = logpdfα!(α_YZ, Γ, (Y, Z))
println(lp)    # println so the two printed values are not fused onto one line

# Sample an X/Y/Z alignment conditioned on α_YZ and report its log-probability.
# Reuse `c` rather than constructing an identical distribution a second time.
c = ConditionedAlignmentDistribution(τ_XYZ, α_YZ)
M_XYZ_data = rand(c)
M_XYZ = Alignment(M_XYZ_data)
println(logpdf(c, M_XYZ_data))
data(M_XYZ)

-1490.1530021439378-24.747921659483893

3×162 Matrix{Integer}:
 1  1  1  1  1  1  1  1  1  1  1  1  1  …  1  1  1  1  1  0  0  0  0  0  0  1
 1  1  1  1  1  1  1  1  1  1  1  1  1     1  1  1  1  1  0  0  0  0  0  0  1
 1  1  1  1  1  1  1  1  1  1  1  1  1     1  1  1  1  1  1  1  1  1  1  1  1

In [8]:
#X = hiddenchain_from_alignment(Y, Z, t_Y, t_Z, M_XYZ, ξ)
# NOTE(review): when this cell last ran, `LengthEquilibriumDefinition` was undefined
# (UndefVarError) — confirm the intended TorusEvol name before relying on this line.
LengthEquilibriumDefinition(λ, μ, r)

LoadError: UndefVarError: `LengthEquilibriumDefinition` not defined

## 2. Bayesian Model for Parameter Inference

### 2.1 Set up priors for evolutionary processes

In [None]:
using Turing, DynamicPPL
using LinearAlgebra
using LogExpFunctions 
using Plots, StatsPlots
using Random

import Base: length, eltype
import Distributions: _rand!, logpdf

Turing.setprogress!(true)

"""
    ScaledBeta(α, β)

`Beta(α, β)` distribution mapped affinely from [0, 1] onto [-1, 1]:
if `B ~ Beta(α, β)` then `2B - 1 ~ ScaledBeta(α, β)`.
"""
struct ScaledBeta{T<:Real} <: ContinuousUnivariateDistribution
    be::Beta{T}    # concrete parametric field (a bare `Beta` field is abstract)
end
ScaledBeta(α::Real, β::Real) = ScaledBeta(Beta(α, β))

# Support of the rescaled distribution
Base.minimum(::ScaledBeta{T}) where {T} = -one(T)
Base.maximum(::ScaledBeta{T}) where {T} = one(T)

# Draw with the supplied RNG (the original ignored `rng`, breaking seeding), then rescale.
Distributions.rand(rng::AbstractRNG, d::ScaledBeta) = 2 * rand(rng, d.be) - 1

# Change of variables x = 2y - 1  ⇒  p_X(x) = p_Y((x + 1) / 2) / 2,
# hence the -log(2) Jacobian term; without it the density integrates to 2.
Distributions.logpdf(d::ScaledBeta, x::Real) = logpdf(d.be, (x + 1) / 2) - log(2)


"""
    CompetingExponential(rate)

Distribution of the sorted pair `(min(a, b), max(a, b))` of two independent
exponential draws `a, b` with the given `rate`. The first component is never
larger than the second.

Note: `Distributions.Exponential(θ)` is parameterised by its *scale* θ = 1/rate,
so the rate is inverted here.
"""
struct CompetingExponential{T<:Real} <: ContinuousMultivariateDistribution
    ex::Exponential{T}    # concrete parametric field
end
CompetingExponential(rate::Real) = CompetingExponential(Exponential(inv(rate)))

Base.eltype(::CompetingExponential) = Float64
Base.length(::CompetingExponential) = 2

# Draw two iid exponentials (same draw order as before) and store them in increasing order.
function Distributions._rand!(rng::AbstractRNG, d::CompetingExponential,
                              x::AbstractVector{<:Real})
    x[1], x[2] = minmax(rand(rng, d.ex), rand(rng, d.ex))
    return x
end

# Density of the order statistics of two iid draws: 2·f(x₁)·f(x₂) for x₁ ≤ x₂, 0 otherwise.
function Distributions._logpdf(d::CompetingExponential, x::AbstractArray)
    x[1] > x[2] && return -Inf
    return log(2) + logpdf(d.ex, x[1]) + logpdf(d.ex, x[2])
end

# Prior over the TKF92 alignment parameters; returns (λ, μ, r).
@model function tkf92_prior()
    # Joint prior on (λ, μ): a sorted pair of iid exponentials, so λ ≤ μ a priori
    λμ ~ CompetingExponential(1.0)
    λ, μ = λμ[1], λμ[2]
    r ~ Uniform(0.0, 1.0)

    # Require birth rate lower than death rate, positive rates and r ∈ (0, 1).
    # Random-walk proposals can leave this region; such draws are flagged by setting
    # λ and μ to NaN so the calling model can reject them.
    if λ > μ || λ ≤ 0 || μ ≤ 0 || r ≤ 0 || r ≥ 1
        μ = NaN; λ = NaN
    end
    return λ, μ, r
end;

# Prior over the jumping wrapped diffusion parameters. Returns them in the positional
# order used by JumpingWrappedDiffusion: (μ_𝜙, μ_𝜓, σ_𝜙, σ_𝜓, α_𝜙, α_𝜓, α_cov, γ).
@model function jwndiff_prior()
    μ ~ filldist(Uniform(-π, π), 2)      # means (μ_𝜙, μ_𝜓)
    σ² ~ filldist(Gamma(π * 0.1), 2)    # squared scales; the process receives √σ²
    α ~ filldist(Gamma(π * 0.1), 2)     # (α_𝜙, α_𝜓)
    γ ~ Exponential(1.0)                # jumping rate
    α_corr ~ ScaledBeta(3, 3)           # in [-1, 1]; rescaled into α_cov below

    # Require valid covariance matrices. Out-of-support draws are flagged with NaN for the
    # calling model to reject. The arrays are rebound to fresh NaN arrays instead of being
    # overwritten in place, so the sampled values themselves are never modified here.
    if any(≤(0), σ²) || any(≤(0), α) || γ ≤ 0
        σ² = fill!(similar(σ²), NaN)
        α = fill!(similar(α), NaN)
        γ = NaN
    end

    α_cov = α_corr * sqrt(α[1] * α[2])
    # α_cov² ≤ α_𝜙·α_𝜓  ⇔  |α_corr| ≤ 1; otherwise flag as invalid
    if α_cov^2 > α[1] * α[2]
        α_cov = NaN
    end

    return μ[1], μ[2], sqrt(σ²[1]), sqrt(σ²[2]), α[1], α[2], α_cov, γ
end;

### 2.2 Set up the sampler

In [None]:
"""Torus proposal: 0.8/0.2 mixture of wrapped normals around `v` with scales `I` and `20I`."""
torus_proposal(v) =
    MixtureModel([WrappedNormal(v, I), WrappedNormal(v, 20 * I)], [0.8, 0.2])

"""Multivariate Gaussian random-walk proposal centred at `v` with covariance `cov`."""
mv_rw_proposal(v::AbstractVector, cov) = MvNormal(v, cov)

"""
    rw_proposal(x, σ)

Univariate Gaussian random-walk proposal centred at `x`. `Normal(μ, σ)` takes a
*standard deviation*, so `σ` is a step size, not a variance.
"""
rw_proposal(x, σ) = Normal(x, σ)

# Metropolis-within-Gibbs: one random-walk MH move per parameter group. The symbols are
# the submodel-prefixed variable names ("Θ.…" / "τ.…") of the inference model.
sampler = Gibbs(
    MH(:t => v -> rw_proposal(v, 0.2)),
    MH(Symbol("Θ.μ") => v -> torus_proposal(v)),
    MH(Symbol("Θ.σ²") => v -> mv_rw_proposal(v, 0.4 * I)),
    MH(Symbol("Θ.α") => v -> mv_rw_proposal(v, 0.4 * I)),
    MH(Symbol("Θ.α_corr") => x -> rw_proposal(x, 0.5)),
    MH(Symbol("Θ.γ") => x -> rw_proposal(x, 0.5)),
    MH(Symbol("τ.λμ") => v -> mv_rw_proposal(v, [0.4 0.1; 0.1 0.6])),
    MH(Symbol("τ.r") => x -> rw_proposal(x, 0.5)),
);

In [None]:
using Memoization 

# Per-pair work buffers reused across model evaluations. The model passes them to the
# mutating `fulljointlogpdf!` / `logpdfαB!`. The TKF92 parameters given to `get_α` are
# dummies — presumably only the pair's sequence lengths fix the buffer shape; TODO confirm.
#
# The buffers are cached per task, not globally. `sample(..., MCMCThreads(), ...)` runs
# chains concurrently, and one global cache would let parallel chains overwrite each other's
# buffers mid-computation. Within a task the cache is keyed on the identity of `pairs`, so
# repeated calls reuse the same buffers.
function get_αs(pairs)
    cache = get!(IdDict{Any,Any}, task_local_storage(), :get_αs_cache)
    return get!(cache, pairs) do
        get_α.(Ref(TKF92([1.0], 0.2, 0.3, 0.4)), pairs)
    end
end

function get_Bs(pairs)
    cache = get!(IdDict{Any,Any}, task_local_storage(), :get_Bs_cache)
    return get!(() -> get_B.(pairs), cache, pairs)
end

### 2.3 Prepare the probabilistic model

In [None]:
using TimerOutputs

# Bayesian model over a shared time `t`, TKF92 parameters Λ and dihedral-process parameters Ξ.
# `pairs` is iterated as (X, Y) chain-data pairs; each pair contributes its joint
# log-probability (marginalising over alignments) to the model's log density.
@model function pair_param_inference_simple(pairs)
    # ________________________________________________________________________________
    # Step 1 - Sample prior parameters

    # Time parameter
    t ~ Exponential(1.0)
    # Alignment parameters (λ, μ, r)
    @submodel prefix="τ" Λ = tkf92_prior()
    # Dihedral parameters
    @submodel prefix="Θ" Ξ = jwndiff_prior()
    # Check parameter validity: the priors flag out-of-support draws with NaN
    if t ≤ 0 || any(isnan, Ξ) || any(isnan, Λ)
        Turing.@addlogprob! -Inf
        return
    end

    # ________________________________________________________________________________
    # Step 2 - Construct processes

    # Substitution Process - no parameters for simplicity, use fully empirical model
    S = WAG_SubstitutionProcess()
    # Dihedral Process
    Θ = JumpingWrappedDiffusion(Ξ...)
    # Joint sequence-structure site level process with one regime.
    # NOTE(review): hcat of a single 2-vector gives a 2×1 matrix — presumably one column per
    # regime; confirm against MixtureProductProcess.
    ξ = MixtureProductProcess([1.0], hcat([S, Θ]))

    # Alignment model
    τ = TKF92([t], Λ...)

    # Chain level model
    Γ = ChainJointDistribution(ξ, τ)

    # ________________________________________________________________________________
    # Step 3 - Observe each pair X, Y by proxy of their joint probability,
    # marginalising over alignments
    α = get_αs(pairs)
    B = get_Bs(pairs)
    for i ∈ eachindex(pairs)
        X, Y = pairs[i]
        # (X, Y) ~ ChainJointDistribution(ξ, τ)
        fulljointlogpdf!(B[i], ξ, t, X, Y)
        Turing.@addlogprob! logpdfαB!(α[i], B[i], Γ, (X, Y))
    end

    return Γ
end;

### 2.4 Sample from the model and check results

In [None]:
# Sampling configuration: draws per chain and number of chains
num_samples, num_chains = 200, 3
# NOTE(review): `simulated_data` is not defined in the visible part of this notebook.
# The model iterates it as (X, Y) pairs — confirm where it is built.
model = pair_param_inference_simple(simulated_data)

In [None]:
# Run `num_chains` chains of `num_samples` draws on multiple threads, then plot the result.
chain = sample(model, sampler, MCMCThreads(), num_samples, num_chains)
p = plot(chain; fontfamily = "JuliaMono")

In [None]:
# Display the sampled chain object as the cell's output
chain