In [1]:
using Distributed
addprocs(44);

In [2]:
using pulse_input_DDM

In [3]:
path = "/mnt/bucket/labs/brody/vdtang/P055_processed"
ratnames=["P055"]
sessids=[[17061800,18041700,16052600,16070500,16111700,18010300,18022600,17122200,
        17081200,17092600,17051900,18030500,17011800,18012200,17011500,17062500,
        17112700,18032700,16082200,18060800,18032500,17082900,17080500,17081000,
        17052300,17090800,17090900,18041500,18040500,17122000]];

In [4]:
#parameters of the latent model
pz = Dict("name" => vcat("σ_i","B", "λ", "σ_a","σ_s","ϕ","τ_ϕ"),
    "fit" => vcat(trues(7)),
    "initial" => vcat(2.,15.,-5.,100.,2.,0.2,0.005),
    "lb" => [eps(), 4., -5., eps(), eps(), eps(), eps()],
    "ub" => [10., 100, 5., 800., 40., 2., 10.])

#parameters for the choice observation
pd = Dict("name" => vcat("bias","lapse"), "fit" => trues(2), 
    "initial" => vcat(0.,0.5));

In [5]:
data = aggregate_choice_data(path, sessids, ratnames)

Dict{String,Any} with 9 entries:
  "rightbups"   => Array{Float64,1}[[0.320735, 0.58898], [0.04424, 0.11641, 0.1…
  "T"           => [1.4695, 1.507, 1.40101, 1.52, 1.49199, 1.4515, 1.401, 1.531…
  "ratID"       => ["P055", "P055", "P055", "P055", "P055", "P055", "P055", "P0…
  "leftbups"    => Array{Float64,1}[[0.002, 0.0592, 0.061615, 0.06239, 0.06786,…
  "stim_start"  => [541.881, 550.909, 573.805, 585.023, 594.216, 598.173, 601.5…
  "pokedR"      => Bool[false, true, false, false, true, false, true, false, tr…
  "correct_dir" => Bool[false, true, true, true, false, false, true, true, fals…
  "sessID"      => [17061800, 17061800, 17061800, 17061800, 17061800, 17061800,…
  "cpoke_end"   => [543.35, 552.416, 575.206, 586.543, 595.708, 599.625, 602.94…

In [6]:
#bin the click inputs at 10 ms
data = bin_clicks!(data;dt=1e-2);

In [10]:
@time pz, pd = load_and_optimize(data;f_tol=1e-9)

MethodError: MethodError: no method matching load_and_optimize(::Dict{String,Any}; f_tol=1.0e-9)
Closest candidates are:
  load_and_optimize(::Any; n, pz, show_trace, iterations) at /usr/people/briandd/.julia/packages/pulse_input_DDM/SIGgo/src/choice_model/wrapper_functions.jl:81 got unsupported keyword argument "f_tol"
  load_and_optimize(!Matched::String, !Matched::Any, !Matched::Any; n, dt, pz, show_trace, iterations) at /usr/people/briandd/.julia/packages/pulse_input_DDM/SIGgo/src/choice_model/wrapper_functions.jl:62 got unsupported keyword argument "f_tol"
  load_and_optimize(!Matched::String, !Matched::Any, !Matched::Any, !Matched::Any; n, dt, delay, pz, show_trace, iterations) at /usr/people/briandd/.julia/packages/pulse_input_DDM/SIGgo/src/neural_model/wrapper_functions.jl:67 got unsupported keyword argument "f_tol"
  ...

In [None]:
#compute the Hessian of the LL landscape, to compute confidence intervals on the parameters
pz, pd = compute_H_CI!(pz, pd, data);

In [None]:
using DataFrames
show(DataFrame(pz),allcols=true)
show(DataFrame(pd),allcols=true)

In [None]:
#simulate choices from the model, given the ML parameters
ML_data = deepcopy(data)
sample_choices_all_trials!(ML_data, pz["final"], pd["final"])

#compute the final click difference, which will dictate the correct choice
ΔLR = map((nT,L,R)->diffLR(nT,L,R,data["dt"])[end], data["nT"], data["leftbups"], data["rightbups"]);

In [None]:
#bin the choices from the generative parameters and from the ML parameters
import Pandas: qcut
import Statistics: mean

nbins = 15;
conds,qcut_bins = qcut(ΔLR, nbins, labels=false, retbins=true);
conds = conds .+ 1;

frac_choice_ML = [mean(ML_data["pokedR"][conds .== i]) for i in 1:nbins]
frac_choice_gen = [mean(data["pokedR"][conds .== i]) for i in 1:nbins];

In [None]:
#fit a GLM to the generative choices and the ML choices
using GLM
GLM_gen = glm(@formula(Y ~ X), DataFrame(X=ΔLR, Y = data["pokedR"]), Binomial(), LogitLink())
GLM_ML = glm(@formula(Y ~ X), DataFrame(X=ΔLR, Y = ML_data["pokedR"]), Binomial(), LogitLink());

In [None]:
using PyPlot

fig = figure(figsize=(6,6))
ax = subplot(111)

scatter(qcut_bins[1:end-1] + diff(qcut_bins)/2, frac_choice_gen, color="red", label="generative")
scatter(qcut_bins[1:end-1] + diff(qcut_bins)/2, frac_choice_ML, color="grey", label="ML")

plot(sort(ΔLR), sort(predict(GLM_gen)), color="red")
plot(sort(ΔLR), sort(predict(GLM_ML)), color="grey")

ylabel("% poked R")
xlabel(L"\Delta{LR}")
ax[:spines]["top"][:set_color]("none") 
ax[:spines]["right"][:set_color]("none")
legend()

In [22]:
sessidsF = [[17061801,18041701,16052601,16070501,16111701,18010301,
        18022601,17122201,17081201,17092601,17051901,18030501,
        17011801,18012201,17011501,17062501,17112701,18032701,16082201,
        18060801,18032501,17082901,17080501,17081001,
        17052301,17090801,17090901,18041501,18040501,17122001]];

In [None]:
#parameters of the latent model
pzF = Dict("name" => vcat("σ_i","B", "λ", "σ_a","σ_s","ϕ","τ_ϕ"),
    "fit" => vcat(trues(7)),
    "initial" => vcat(2.,15.,-5.,100.,2.,0.2,0.005),
    "lb" => [eps(), 4., -5., eps(), eps(), eps(), eps()],
    "ub" => [10., 100, 5., 800., 40., 2., 10.])

#parameters for the choice observation
pdF = Dict("name" => vcat("bias","lapse"), "fit" => trues(2), 
    "initial" => vcat(0.,0.5));

In [23]:
dataF = aggregate_choice_data(path, sessidsF, ratnames)

Dict{String,Any} with 13 entries:
  "binned_leftbups"  => Array{Int64,1}[[], [4, 9, 9, 11, 12, 14, 14, 16, 16, 16…
  "T"                => [1.401, 1.40099, 1.404, 1.4705, 1.437, 1.4005, 1.40101,…
  "stim_start"       => [821.41, 824.232, 827.015, 837.089, 839.737, 842.788, 8…
  "nT"               => [141, 141, 141, 148, 144, 141, 141, 141, 142, 160  …  1…
  "dt"               => 0.01
  "cpoke_end"        => [822.811, 825.633, 828.419, 838.56, 841.174, 844.188, 8…
  "binned_rightbups" => Array{Int64,1}[[8, 11, 20, 29, 32, 34, 38, 39, 40, 42  …
  "ratID"            => ["P055", "P055", "P055", "P055", "P055", "P055", "P055"…
  "leftbups"         => Array{Float64,1}[[], [0.037375, 0.08617, 0.08773, 0.101…
  "correct_dir"      => Bool[false, false, true, false, false, false, false, fa…
  "rightbups"        => Array{Float64,1}[[0.072645, 0.108225, 0.19472, 0.28751,…
  "pokedR"           => Bool[false, false, true, false, false, false, false, fa…
  "sessID"           => [17061801, 17061801, 1

In [None]:
#bin the click inputs at 10 ms
dataF = bin_clicks!(dataF;dt=1e-2);

In [None]:
@time pzF, pdF = load_and_optimize(dataF)

In [31]:
JLD2.@save "/usr/people/briandd/Projects/pulse_input_DDM.jl/data/results/marino.jld" pz pd pzF pdF