In [None]:
using PulseInputDDM
using Parameters, MAT
using PyPlot

# Save backward pass

In [None]:
"""
    P, M, xc, dx = initialize_latent_model(σ2_i, B, λ, σ2_a, n, dt, lapse=0.)

Create the trial-independent quantities needed to propagate the accumulator
distribution: the initial distribution `P`, the transition matrix `M` for time bins
without clicks, the bin centers `xc` and the bin width `dx`.

`lapse` is a positional (not keyword) argument. Before the initial variance `σ2_i`
is applied, mass `1 - lapse` is put on the central bin and `lapse/2` on each of the
two edge bins (`1` and `n`); the original notes say this helps produce faithful
backward-pass values.

# Arguments
- `σ2_i`: initial-condition variance used to spread the starting mass.
- `B`: bound, passed with `n` to `PulseInputDDM.bins`.
- `λ`, `σ2_a`: leak and accumulation-noise rate used to build `M`
  (per-bin variance `σ2_a * dt`).
- `n`: number of bins.
- `dt`: time-bin width.
- `lapse`: starting mass split between the two edge bins.

# Returns
- `P`: initial distribution over the `n` bins.
- `M`: `n × n` matrix applied as `P_next = M * P` on bins without clicks.
- `xc`, `dx`: bin centers and bin width from `PulseInputDDM.bins`.
"""
function initialize_latent_model(σ2_i::TT, B::TT, λ::TT, σ2_a::TT,
    n::Int, dt::Float64, lapse::UU=0.) where {TT,UU <: Any}

    xc, dx = PulseInputDDM.bins(B, n)

    # Starting mass: mostly at the center, with the lapse split across the edge bins.
    P = zeros(TT, n)
    P[ceil(Int, n / 2)] = one(TT) - lapse
    P[1], P[n] = lapse / 2., lapse / 2.

    # Spread the starting mass with the initial variance only (no drift, no leak).
    M0 = PulseInputDDM.transition_M(σ2_i, zero(TT), zero(TT), dx, xc, n, dt)
    P = M0 * P

    # Click-free transition: accumulation noise σ2_a*dt and leak λ (the same
    # parameters the backward pass uses for a bin with no clicks). Returning the
    # σ2_i matrix here instead would ignore λ and σ2_a and re-apply the initial
    # variance at every click-free bin of the forward pass.
    M = PulseInputDDM.transition_M(σ2_a * dt, λ, zero(TT), dx, xc, n, dt)

    return P, M, xc, dx
end


"""
    logsumexp(x)

Return `log(sum(exp.(x)))`, computed stably by factoring out `m = maximum(x)`.

If `m` is not finite (every entry `-Inf`, or some entry `+Inf`/`NaN`), `m` is
returned directly; the shifted form would evaluate `-Inf - (-Inf)` or `Inf - Inf`
and produce `NaN`. An empty `x` throws, as `maximum` does.
"""
function logsumexp(x)
    m = maximum(x)
    isfinite(m) || return m
    return m + log(sum(exp.(x .- m)))
end


"""
    Σ, Δ = ΣLR_ΔLR(t, nL, nR, La, Ra)

Sum the click magnitudes that fall in time bin `t`.

`sL` is the sum of the entries of `La` whose bin index in `nL` equals `t`, and `sR`
is the same for `Ra`/`nR`. A side with no click in bin `t` contributes `zero(TT)`.
Returns the tuple `(sL + sR, -sL + sR)`.
"""
function ΣLR_ΔLR(t::Int, nL::Vector{Int}, nR::Vector{Int},
        La::Vector{TT}, Ra::Vector{TT}) where {TT <: Any}

    inL = nL .== t
    inR = nR .== t

    left = any(inL) ? sum(La[inL]) : zero(TT)
    right = any(inR) ? sum(Ra[inR]) : zero(TT)

    return left + right, -left + right
end



"""
    β = bing_backward(θ, data, P, M, dx, xc, n, cross)

Forward–backward smoother for a single trial. Returns the posterior over accumulator
bins at every time bin, conditioned on all clicks of the trial and on the choice
(refer to Brunton '13 supplementary, section 3.2.3).

# Arguments
- `θ::θchoice`: model parameters; uses `bias` and `λ, σ2_a, σ2_s, ϕ, τ_ϕ` from `θ.θz`.
- `data::PulseInputDDM.choicedata`: one trial (clicks, binned clicks, `dt`, choice).
- `P`: initial distribution over the `n` bins.
- `M`: the matrix handed to `PulseInputDDM.latent_one_step!` in the forward pass.
- `dx`, `xc`: bin width and bin centers, e.g. from `initialize_latent_model`.
- `n`: number of bins.
- `cross`: forwarded as the `cross` keyword of `PulseInputDDM.adapt_clicks`.

# Returns
An `n × nT` matrix; column `t` is the normalized posterior at time bin `t`.
"""
function bing_backward(θ::θchoice, data::PulseInputDDM.choicedata, P::Vector{TT},
    M::Array{TT,2}, dx::UU, xc::Vector{TT}, n::Int, cross::Bool) where {TT, UU <: Real}

    @unpack click_data, choice = data
    @unpack binned_clicks, clicks, dt = click_data
    @unpack nT, nL, nR = binned_clicks
    @unpack L, R = clicks
    @unpack θz, bias = θ
    @unpack λ, σ2_a, σ2_s, ϕ, τ_ϕ = θz

    # Adapted click magnitudes; the backward pass rebuilds the same transitions from them.
    La, Ra = PulseInputDDM.adapt_clicks(ϕ, τ_ϕ, L, R; cross=cross)

    # FORWARD PASS: α[:, t] = P(a_t | clicks in bins 1..t, θ), normalized.
    # Step t uses the clicks of bin t, i.e. α_t ∝ T_t * α_{t-1}.
    F = zeros(TT, n, n)
    α = Array{TT,2}(undef, n, nT)
    @inbounds for t = 1:nT
        P, F = PulseInputDDM.latent_one_step!(P, F, λ, σ2_a, σ2_s, t, nL, nR, La, Ra,
            M, dx, xc, n, dt)
        P /= sum(P)
        α[:, t] = P
    end

    # BACKWARD PASS: β[:, t] = P(a_t | all clicks, choice, θ).
    β = Array{TT,2}(undef, n, nT)
    # Last bin: combine the final forward distribution with the observed choice.
    # (choice_likelihood! may modify its `P` argument — note the `!`.)
    β[:, end] = PulseInputDDM.choice_likelihood!(bias, xc, P, choice, n, dx)
    β[:, end] /= sum(β[:, end])
    P = β[:, end]
    F = zeros(TT, n, n)

    @inbounds for t = nT-1:-1:1
        # Rebuild T_{t+1}, the transition from bin t to bin t+1. It is driven by the
        # clicks of bin t+1 (not bin t). latent_one_step! is handed `M`, presumably
        # for click-free bins, so mirror that here.
        # NOTE(review): confirm against PulseInputDDM.latent_one_step!.
        Σ, μ = ΣLR_ΔLR(t + 1, nL, nR, La, Ra)
        if Σ > zero(TT)
            PulseInputDDM.transition_M!(F, σ2_s * Σ + σ2_a * dt, λ, μ, dx, xc, n, dt)
            Tnext = F
        else
            Tnext = M
        end

        # Transitions are applied as P_next = T * P, so T[j, i] is the probability of
        # moving from bin i to bin j. Bayes' rule gives
        #   P(a_t = i | a_{t+1} = j, clicks 1..t) = T[j, i] α_t[i] / α_{t+1}[j],
        # hence β_t[i] ∝ α_t[i] * Σ_j T[j, i] β_{t+1}[j] / α_{t+1}[j].
        # Bins the forward pass gives zero probability contribute nothing.
        ratio = zeros(TT, n)
        for j in 1:n
            if α[j, t + 1] > zero(TT)
                ratio[j] = P[j] / α[j, t + 1]
            end
        end
        back = zeros(TT, n)
        for i in 1:n, j in 1:n
            back[i] += Tnext[j, i] * ratio[j]
        end

        P = α[:, t] .* back
        P /= sum(P)
        β[:, t] = P
    end

    return β
end


In [None]:
# Batch job: for every rat and each of its raw-data sessions, compute the
# backward-pass posterior of every trial and save it next to the data as
# `<session>_accum_backward.mat`, together with the fit file name and bin centers.
dt = 0.001      # time-bin width requested from load_choice_data
n = 53          # number of accumulator bins
cross = true    # forwarded to adapt_clicks inside bing_backward

ratnames = ["X046", "X062", "X087", "A294", "A297"]
datapath = ENV["HOME"]* "/ondrive/analysisDG/PBups_Phys/manuscript/figure_code/saved_results/behavior_data/"
fitfile = [
    "X046_last120days_k8_bing_apr21_2023.mat",
    "X062_last120days_Ot_bing_apr21_2023.mat",
    "X087_last120days_Zy_bing_apr21_2023.mat",
    "A294_last120days_pw_bing_apr21_2023.mat",
    "A297_last120days_gd_bing_apr21_2023.mat"]

for (i, (ratname, fitname)) in enumerate(zip(ratnames, fitfile))

    println("\n Rat $i : $ratname  ($fitname)")

    fitmatches = filter(x -> occursin(fitname, x), readdir(datapath))
    isempty(fitmatches) && error("no fit file matching $fitname in $datapath")
    prmfile = first(fitmatches)
    θ, options = reload_choice_model(datapath*prmfile)
    @unpack lapse_prob = θ.θlapse
    @unpack σ2_i, B, λ, σ2_a = θ.θz

    rawpattern = Regex(string(ratname, raw".*rawdata\.mat$"))
    datafiles = filter(x -> occursin(rawpattern, x), readdir(datapath))
    for datafile in datafiles
        data = load_choice_data(datapath*datafile, dt = dt)
        # Use the bin width stored with the loaded data. A separate name avoids
        # rebinding the global `dt`; in a non-interactive run that rebinding would
        # make `dt` an undefined local at the load_choice_data call above.
        data_dt = data[1].click_data.dt

        P, M, xc, dx = initialize_latent_model(σ2_i, B, λ, σ2_a, n, data_dt, lapse_prob)
        backward = map(trial -> bing_backward(θ, trial, P, M, dx, xc, n, cross), data)
        dict = Dict("backward" => backward, "prmfile" => prmfile, "xc" => xc)
        matwrite(datapath*first(splitext(datafile))*"_accum_backward.mat", dict)
    end

end

In [None]:
# Interactive inspection: pick one rat (`i`), reload its fitted parameters and load
# its first raw-data session into globals used by the following cells.
dt = 0.001
n = 53
cross = true

ratnames = ["X046", "X062", "X087", "A294", "A297"]
datapath = ENV["HOME"]* "/ondrive/analysisDG/PBups_Phys/manuscript/figure_code/saved_results/behavior_data/"
fitfile = [
    "X046_last120days_k8_bing_apr21_2023.mat",
    "X062_last120days_Ot_bing_apr21_2023.mat",
    "X087_last120days_Zy_bing_apr21_2023.mat",
    "A294_last120days_pw_bing_apr21_2023.mat",
    "A297_last120days_gd_bing_apr21_2023.mat"]

i = 5
print("\n Rat $(i) : $(ratnames[i])")

# Fitted parameters for this rat.
str = fitfile[i]
print(str)
prmfile = filter(fname -> occursin(str, fname), readdir(datapath))[1]
θ, options = reload_choice_model(datapath * prmfile)
@unpack bias = θ
@unpack lapse_prob = θ.θlapse
@unpack σ2_i, B, λ, σ2_a = θ.θz

# First raw-data session of this rat; dt is then taken from the loaded data.
str = Regex(ratnames[i] * ".*rawdata.mat")
datafiles = filter(fname -> occursin(str, fname), readdir(datapath))
d = 1
data = load_choice_data(datapath * datafiles[d], dt = dt);
@unpack dt = data[1].click_data




In [None]:
# Compute the backward-pass posterior for one trial and show it as an image
# (rows: accumulator bins, columns: time bins), then print `clicks.T` for that
# trial (presumably the trial duration).
tr = 15
trial = data[tr]

P, M, xc, dx = initialize_latent_model(σ2_i, B, λ, σ2_a, n, dt, lapse_prob)
backward = bing_backward(θ, trial, P, M, dx, xc, n, cross);

plt.imshow(backward; aspect="auto", vmax=0.1)
PyPlot.display_figs()
println(trial.click_data.clicks.T)

In [None]:
# Print the raw clicks and the recorded choice of the trial plotted above.
let trial = data[tr]
    println(trial.click_data.clicks)
    println(trial.choice)
end

In [None]:
# Load the fitted parameters and the first raw-data session for rat `i`,
# printing the list of matching raw-data files along the way.
dt = 0.001
n = 53
cross = true

ratnames = ["X046", "X062", "X087", "A294", "A297"]
datapath = ENV["HOME"]* "/ondrive/analysisDG/PBups_Phys/manuscript/figure_code/saved_results/behavior_data/"
fitfile = [
    "X046_last120days_k8_bing_apr21_2023.mat",
    "X062_last120days_Ot_bing_apr21_2023.mat",
    "X087_last120days_Zy_bing_apr21_2023.mat",
    "A294_last120days_pw_bing_apr21_2023.mat",
    "A297_last120days_gd_bing_apr21_2023.mat"]


i = 1

print("\n Rat $(i) : $(ratnames[i])")

str = fitfile[i]
print(str)
prmfile = filter(fname -> occursin(str, fname), readdir(datapath))[1]
θ, options = reload_choice_model(datapath * prmfile)

str = Regex(ratnames[i] * ".*rawdata.mat")
datafiles = filter(fname -> occursin(str, fname), readdir(datapath))
print(datafiles)
data = load_choice_data(datapath * datafiles[1], dt = dt);


In [None]:
# Display the raw-data files matched for the rat selected in the previous cell.
datafiles