## Demo notebook for fitting the model of _Brunton et al. (2013)_

This is boilerplate that starts the worker processes and loads the packages needed to run the notebook.

In [1]:
using Distributed
# Spawn worker processes for the parallel work below (`@everywhere using ...`, `pmap`).
# Guarded so that re-running this cell does not add yet another batch of workers,
# and capped at the number of hardware threads so the notebook still runs on
# machines with fewer than 44 of them.
if nprocs() == 1
    addprocs(min(44, Sys.CPU_THREADS));
end;

In [19]:
@everywhere using pulse_input_DDM

In [20]:
using PyPlot, JLD2, DataFrames
using LinearAlgebra, Pandas, GLM

Define dictionaries holding the generative parameters, a Boolean vector marking which parameters to fit, and their initial values (plus lower and upper bounds for the latent-model parameters).

In [4]:
# Latent-model parameters. Every entry is a length-7 vector in the order given by
# "name"; "lb"/"ub" are presumably the bounds applied during fitting.
# NOTE(review): the initial λ (-5.0) sits exactly on its lower bound (-5.0) —
# confirm the optimiser accepts a starting point on the boundary.
pz_gen = let
    labels = ["σ_i", "B", "λ", "σ_a", "σ_s", "ϕ", "τ_ϕ"]
    truth  = [1.0, 13.0, -0.5, 10.0, 1.0, 0.4, 0.02]
    start  = [2.0, 15.0, -5.0, 100.0, 2.0, 0.2, 0.005]
    lower  = [eps(), 4.0, -5.0, eps(), eps(), eps(), eps()]
    upper  = [10.0, 100.0, 5.0, 800.0, 40.0, 2.0, 10.0]
    Dict("generative" => truth, "name" => labels, "fit" => trues(7),
         "initial" => start, "lb" => lower, "ub" => upper)
end

# Choice-observation parameters, in the order given by "name".
pd_gen = Dict("generative" => [0.1, 0.1], "name" => ["bias", "lapse"],
              "fit" => trues(2), "initial" => [0.0, 0.5]);

## Simulate choices from the model, to check that the generative parameters can be recovered

In [5]:
# Number of trials to simulate, and the width of one time bin (in seconds).
ntrials, dt = 5_000, 0.01

(5000, 0.01)

In [6]:
#generate some simulated click times and trial durations
# Uses `ntrials` and `dt` from the previous cell. The result is used below as a Dict
# with keys such as "T", "nT", "leftbups" and "rightbups".
data = pulse_input_DDM.sample_clicks(ntrials,dt);

In [21]:
#simulate choices from the model, given the generative parameters
# Mutates `data` in place (note the `!`); later cells read `data["pokedR"]`, so this is
# presumably where the simulated choices are added. `rng=4` looks like a seed — confirm.
# NOTE(review): the recorded output is a RemoteException (`UndefVarError: ##197#198`
# on a worker). The execution counters show this cell ran after
# `@everywhere using pulse_input_DDM` was re-executed out of order; one likely cause is
# the package differing between master and workers. Restart the kernel and run the
# notebook top-to-bottom before debugging further — TODO confirm.
pulse_input_DDM.sampled_dataset!(data, pz_gen["generative"], pd_gen["generative"]; num_reps=1, rng=4);

RemoteException: On worker 34:
UndefVarError: ##197#198 not defined
deserialize_datatype at /home/conda/feedstock_root/build_artifacts/julia_1548684429855/work/usr/share/julia/stdlib/v1.0/Serialization/src/Serialization.jl:1095
handle_deserialize at /home/conda/feedstock_root/build_artifacts/julia_1548684429855/work/usr/share/julia/stdlib/v1.0/Serialization/src/Serialization.jl:751
deserialize at /home/conda/feedstock_root/build_artifacts/julia_1548684429855/work/usr/share/julia/stdlib/v1.0/Serialization/src/Serialization.jl:711
deserialize_datatype at /home/conda/feedstock_root/build_artifacts/julia_1548684429855/work/usr/share/julia/stdlib/v1.0/Serialization/src/Serialization.jl:1119
handle_deserialize at /home/conda/feedstock_root/build_artifacts/julia_1548684429855/work/usr/share/julia/stdlib/v1.0/Serialization/src/Serialization.jl:751
deserialize at /home/conda/feedstock_root/build_artifacts/julia_1548684429855/work/usr/share/julia/stdlib/v1.0/Serialization/src/Serialization.jl:711
handle_deserialize at /home/conda/feedstock_root/build_artifacts/julia_1548684429855/work/usr/share/julia/stdlib/v1.0/Serialization/src/Serialization.jl:758
deserialize at /home/conda/feedstock_root/build_artifacts/julia_1548684429855/work/usr/share/julia/stdlib/v1.0/Serialization/src/Serialization.jl:711
deserialize_msg at /home/conda/feedstock_root/build_artifacts/julia_1548684429855/work/usr/share/julia/stdlib/v1.0/Distributed/src/messages.jl:99
#invokelatest#1 at ./essentials.jl:697 [inlined]
invokelatest at ./essentials.jl:696 [inlined]
message_handler_loop at /home/conda/feedstock_root/build_artifacts/julia_1548684429855/work/usr/share/julia/stdlib/v1.0/Distributed/src/process_messages.jl:160
process_tcp_streams at /home/conda/feedstock_root/build_artifacts/julia_1548684429855/work/usr/share/julia/stdlib/v1.0/Distributed/src/process_messages.jl:117
#105 at ./task.jl:259

In [8]:
#compute the likelihood of the data, given the generative parameters
# Presumably returns the log-likelihood ("LL") — confirm against pulse_input_DDM.
# NOTE(review): the recorded `KeyError: key "pokedR" not found` is consistent with the
# previous cell having failed, leaving `data` without its simulated choices.
compute_LL(pz_gen["generative"], pd_gen["generative"],data)

KeyError: KeyError: key "pokedR" not found

In [None]:
#find the ML parameters with gradient descent
# Rebinds `pz_gen`/`pd_gen` to the first two values returned by `optimize_model`
# (the trailing comma ignores any further return values); `@time` reports run time.
@time pz_gen, pd_gen, = optimize_model(pz_gen, pd_gen, data)

In [None]:
#compute the Hessian of the LL landscape, to compute confidence intervals on the parameters
# NOTE(review): the return value is not assigned to anything, but the next cell reads the
# CI entries from `pz_gen`/`pd_gen`. That only works if `compute_H_CI` fills them in
# place — confirm (a mutating function would conventionally be named `compute_H_CI!`).
compute_H_CI(pz_gen, pd_gen, data, dt)

In [None]:
#identify which ML parameters have generative parameters within the CI
# `pz_gen`/`pd_gen` are keyed by Strings (they are built with "generative", "name", ...
# and looked up as `pz_gen["generative"]` elsewhere). Symbol keys such as `:generative`
# would raise a KeyError on lookup, and `:within_bounds` cannot be converted to a
# String key on assignment.
# NOTE(review): assumes compute_H_CI stores the intervals under "CI_minus"/"CI_plus".
for params in (pz_gen, pd_gen)
    params["within_bounds"] = (params["CI_minus"] .< params["generative"]) .&
                              (params["CI_plus"] .> params["generative"])
end

In [None]:
# Tabulate the latent-model parameters, one column per key. This requires every value in
# `pz_gen` to be a vector of the same length (one row per parameter) — TODO confirm the
# keys added by the fitting cells keep that shape.
DataFrames.DataFrame(pz_gen)

In [None]:
# Display the choice-observation parameter dictionary as-is.
pd_gen

In [None]:
#sample data from the model using the ML parameters
# The dicts are keyed by Strings, so the fitted values are read from "final" (a Symbol
# key would raise a KeyError). NOTE(review): assumes optimize_model stores them there.
# Binding the two vectors locally means the closure shipped to the workers by `pmap`
# carries just those vectors rather than referencing the global dicts.
sampled_choices = let θz = pz_gen["final"], θd = pd_gen["final"]
    pmap((T, L, R) -> sample_model(T, L, R, θz, θd),
         data["T"], data["leftbups"], data["rightbups"])
end

#compute the final click difference, which will dictate the correct choice
# Only the last element of each `diffLR` result is kept (presumably the cumulative
# right-minus-left difference at the end of the trial — confirm).
ΔLR = map((nT, L, R) -> diffLR(nT, L, R, dt)[end],
          data["nT"], data["leftbups"], data["rightbups"]);

In [None]:
#bin the real choices and the sampled choices
# Quantile-bin ΔLR with Pandas' `qcut`. `labels=false` yields 0-based bin indices,
# which are shifted to 1-based for Julia indexing; `qcut_bins` holds the bin edges.
nbins = 15;
conds, qcut_bins = qcut(ΔLR, nbins, labels=false, retbins=true);
conds = conds .+ 1;

# Per-bin mean choice, for the model samples and for the real data
# (the fraction of rightward pokes, if the choices are Boolean).
frac_choice_sampled = Float64[mean(sampled_choices[conds .== b]) for b in 1:nbins]
frac_choice = Float64[mean(data["pokedR"][conds .== b]) for b in 1:nbins];

In [None]:
#fit a GLM to the real choices and the simulated choices
# Both are logistic regressions (Binomial family, logit link) of choice Y on the
# final click difference X, sharing the same X values.
GLM_dataframe    = DataFrames.DataFrame(X= ΔLR, Y = data["pokedR"])
GLM_samplesframe = DataFrames.DataFrame(X= ΔLR, Y = sampled_choices)

GLM_data    = glm(@formula(Y ~ X), GLM_dataframe,    Binomial(), LogitLink())
GLM_samples = glm(@formula(Y ~ X), GLM_samplesframe, Binomial(), LogitLink())

In [None]:
fig = figure(figsize=(6,6))
ax = subplot(111)

# x-position of each quantile bin: the midpoint between consecutive bin edges.
bin_centers = qcut_bins[1:end-1] + diff(qcut_bins)/2
scatter(bin_centers, frac_choice, color="red", label="data")
scatter(bin_centers, frac_choice_sampled, color="grey", label="sampled")

# Draw each fitted logistic curve in order of increasing X. Sorting X and the
# predictions independently would only pair them correctly when the fitted slope is
# positive; `sortperm` keeps every prediction attached to its own X value.
order_data = sortperm(GLM_dataframe[:X])
PyPlot.plot(GLM_dataframe[:X][order_data], predict(GLM_data)[order_data],
            color="red")
order_samples = sortperm(GLM_samplesframe[:X])
PyPlot.plot(GLM_samplesframe[:X][order_samples], predict(GLM_samples)[order_samples],
            color="grey")

ylabel("% poked R")
xlabel(L"\Delta{LR}")
ax[:spines]["top"][:set_color]("none") # Remove the top axis boundary
ax[:spines]["right"][:set_color]("none") # Remove the right axis boundary
legend()

## Fit real data

In [None]:
# Where the processed data live, and which rat(s) and sessions to load. `sessids`
# presumably holds one vector of session IDs per entry of `ratnames`, in the same order.
ratnames = ["P055"]
path = "/mnt/bucket/labs/brody/vdtang/P055_processed"
sessids = [[17061800, 18041700, 16052600, 16070500, 16111700, 18010300,
            18022600, 17122200, 17081200, 17092600, 17051900, 18030500,
            17011800, 18012200, 17011500, 17062500, 17112700, 18032700,
            16082200, 18060800, 18032500, 17082900, 17080500, 17081000,
            17052300, 17090800, 17090900, 18041500, 18040500, 17122000]];

In [None]:
# Presumably loads and pools the choice data for the sessions in `sessids` of the rats in
# `ratnames` from `path`, binned at `dt` (inferred from the name — confirm).
# Note this rebinds `data`, replacing the simulated dataset used above.
data = aggregate_choice_data(path, sessids, ratnames, dt)

## Save results

In [None]:
# NOTE(review): `pz`, `pd`, `pzF` and `pdF` are not assigned anywhere in the visible
# notebook. The simulated-data fit is stored in `pz_gen`/`pd_gen`, and no fit is run
# after `aggregate_choice_data`. This `@save` will therefore raise an UndefVarError
# unless those variables are defined first. The output path is also user-specific.
JLD2.@save "/usr/people/briandd/Projects/pulse_input_DDM.jl/data/results/marino.jld" pz pd pzF pdF