## 1. Load Data

In [1]:
using TorusEvol
using Distributions
using TimerOutputs

# Fetch chain A of PDB entries 1A3N and 1MBN and extract their observation data.
chainX = from_pdb("1A3N", "A")
chainY = from_pdb("1MBN", "A")
X = data(chainX)
Y = data(chainY)

# Overwrite NaN entries of the second data component with 0.0, in place.
# NOTE(review): `data` is applied again to the already-extracted `X`/`Y`; this only
# cleans the stored arrays if `data(obs)[2]` aliases them rather than copying — confirm.
for obs in (X, Y)
    component = data(obs)[2]
    component[isnan.(component)] .= 0.0
end

# Two jumping wrapped-diffusion processes that differ only in their 3rd/4th arguments
# (0.1 vs 2.0) and their 7th argument (0.2 vs 0.1); both use 10.0 as the second
# argument to `jumping`. Presumably the WrappedDiffusion argument order is
# (μ_𝜙, μ_𝜓, σ_𝜙, σ_𝜓, α_𝜙, α_𝜓, α_cov), as in the inference model — confirm.
my_diff = jumping(WrappedDiffusion(-1.0, -0.8, 0.1, 0.1, 1.0, 1.0, 0.2), 10.0)
my_diff2 = jumping(WrappedDiffusion(-1.0, -0.8, 2.0, 2.0, 1.0, 1.0, 0.1), 10.0)

# Two-regime mixture: row 1 holds the WAG substitution process, row 2 the diffusion
# processes; one column per regime, weighted 0.8 / 0.2.
w = WAG_SubstitutionProcess()
processes = [w w; my_diff my_diff2]
weights = [0.8, 0.2]
ξ = MixtureProductProcess(weights, processes)

# TKF92 alignment model with time t and rates λ, μ, r, and a pair HMM sized to X and Y.
t = 1.58
λ = 0.03
μ = 0.032
r = 0.3
align_model = TKF92([t], λ, μ, r)
pair_hmm = PairDataHMM(align_model, num_sites(X), num_sites(Y));

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling TorusEvol [4b860a26-b3bc-4b38-a9ed-83c1dc5d19b0]
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile exists: 1A3N
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFile exists: 1MBN


## 2. Inference Model

In [2]:
using Turing 
using ReverseDiff
using LogExpFunctions
using FastLogSumExp
# Reverse-mode AD with a cached tape (tape caching is only valid when the model has no
# value-dependent control flow).
# Alternative: Turing.setadbackend(:forwarddiff)
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)
Turing.setprogress!(true)

# Turing model for the evolutionary parameters shared by a collection of pairs.
#
# Arguments
# - `pairs`: iterable of `(X, Y)` observation pairs (in this notebook, a `zip` of data
#   extracted with `data(from_pdb(...))`).
# - `E`: number of mixture regimes, i.e. columns of the `MixtureProductProcess`;
#   must be ≥ 1 (default 1).
#
# Only the branch length `t` is sampled (prior `Exponential(1.0)`); all other
# parameters are currently fixed constants. Each pair contributes
# `logpdf(pair_hmm, emission_lps)` to the joint log density via `@addlogprob!`
# (presumably the pair-HMM likelihood marginalised over alignments — confirm).
@model function pair_param_inference(pairs; E=1)
    E >= 1 || throw(ArgumentError("E must be a positive number of regimes, got $E"))

    # Branch length (time) shared by every pair.
    t ~ Exponential(1.0)

    # TKF92 alignment parameters — currently fixed.
    # TODO: sample these in a stand-alone submodel (e.g. λ_a ~ Exponential(0.01)).
    λ_a = 0.03
    # μ_a = λ_a * (1 + 1/145), i.e. λ_a / μ_a = 145 / 146; presumably 145 is the target
    # sequence length (cf. the disabled `seq_length` prior) — confirm.
    seq_inv_length = 1 / 145
    μ_a = (seq_inv_length + 1) * λ_a
    r_a = 0.7
    align_model = TKF92([t], λ_a, μ_a, r_a)

    # Regime weights: uniform over the E regimes (the mean of the disabled
    # `Dirichlet(E, 1.0)` prior). Sized by E so the weights always match the number of
    # process columns; equals `[1.0]` for the default E = 1.
    proc_weights = fill(1.0 / E, E)

    # Amino-acid substitution: the same WAG process in every regime, as a 1×E row.
    # TODO: construct a custom (sampled) substitution process.
    sub_procs = fill(WAG_SubstitutionProcess(), 1, E)

    # Dihedral-angle diffusion parameters, one value per regime — currently fixed.
    # TODO: sample these in a stand-alone submodel, e.g.
    #   μ_𝜙[e], μ_𝜓[e] ~ Uniform(-π, π);  σ_𝜙[e], σ_𝜓[e] ~ Exponential(0.5)
    μ_𝜙 = fill(0.3, E); μ_𝜓 = fill(0.3, E)
    σ_𝜙 = fill(0.3, E); σ_𝜓 = fill(0.3, E)
    α_𝜙 = fill(1.0, E); α_𝜓 = fill(1.0, E); α_cov = fill(0.1, E)
    γ = fill(20.0, E)  # second argument to `jumping`, per regime
    diff_procs = reshape(
        jumping.(WrappedDiffusion.(μ_𝜙, μ_𝜓, σ_𝜙, σ_𝜓, α_𝜙, α_𝜓, α_cov), γ), 1, E
    )

    # 2×E process matrix: row 1 substitution, row 2 diffusion; one column per regime.
    ξ = MixtureProductProcess(proc_weights, vcat(sub_procs, diff_procs))

    # Observe X, Y by proxy of their joint probabilities.
    for (X, Y) ∈ pairs
        emission_lps = fulllogpdf(ξ, t, X, Y)
        pair_hmm = PairDataHMM(align_model, num_sites(X), num_sites(Y))
        Turing.@addlogprob! logpdf(pair_hmm, emission_lps)
    end
end

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m[Turing]: progress logging is enabled globally
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m[AdvancedVI]: global PROGRESS is set as true


pair_param_inference (generic function with 2 methods)

In [None]:
using StatsPlots

# Dataset: the same (X, Y) pair repeated `data_size` times.
data_size = 3
Xs = fill(X, data_size)
Ys = fill(Y, data_size)
pairs = zip(Xs, Ys)
model = pair_param_inference(pairs)

# Candidate samplers; only `hmc1` is used below.
# NOTE(review): `mh1` and `algmh` name :r_a, :λ_a, :μ_𝜙 and :μ_𝜓, which
# `pair_param_inference` currently assigns as constants rather than sampling with `~`;
# those priors must be re-enabled before these samplers can be used.
alg = Gibbs(HMC(0.1, 5, :t))
hmc1 = Gibbs(HMC(0.02, 5, :t))
mh1 = Gibbs(MH(:t), MH(:r_a), MH(:λ_a))
algmh = Gibbs(MH(:t), MH(:λ_a), MH(:μ_𝜙), MH(:μ_𝜓), MH(:r_a))

# Sample with HMC on `t`, starting from a freshly reset timer.
# NOTE(review): `to` is not defined in this notebook; presumably a TimerOutput
# exported by TorusEvol — confirm.
numsamples = 10
reset_timer!(to)
ch = sample(model, hmc1, numsamples)

In [None]:
# Print the timings accumulated in `to` during sampling.
print(stdout, to)

In [None]:
# Plot the sampled chain `ch` using the StatsPlots recipe for chains.
p = plot(ch; fontfamily = "JuliaMono")
# Uncomment to save the figure:
#savefig(p, "images/pairinfer2mh.svg")

In [None]:
# Plot the per-iteration log probability stored under `:lp` in the chain.
p = plot(
    get(ch, :lp).lp;
    xlabel = "Iterations",
    label = "Sequence joint log probability",
)
# Uncomment to save the figure:
#savefig(p, "images/pairinferlp2mh.svg")