In [1]:
using Distributed
# Start 44 local worker processes.
# NOTE(review): presumably pulse_input_DDM distributes its likelihood computation
# over these workers — confirm, and match the count to the cores available on the
# machine running this notebook.
addprocs(44);

In [2]:
using pulse_input_DDM

In [3]:
# Where the processed data live and which sessions to load.
# `sessids` is a vector of session-ID vectors; presumably one inner vector per
# entry of `ratnames` (here a single rat, P055) — confirm against
# aggregate_choice_data.
path = "/mnt/bucket/labs/brody/vdtang/P055_processed"
ratnames = ["P055"]
sessids = [[
    17061800, 18041700, 16052600, 16070500, 16111700, 18010300,
    18022600, 17122200, 17081200, 17092600, 17051900, 18030500,
    17011800, 18012200, 17011500, 17062500, 17112700, 18032700,
    16082200, 18060800, 18032500, 17082900, 17080500, 17081000,
    17052300, 17090800, 17090900, 18041500, 18040500, 17122000]];

In [4]:
# Latent (accumulator) model parameters: names, which ones to fit, starting
# values, and lower/upper bounds, all in the same order as "name".
# NOTE(review): the initial λ (-5.0) sits exactly on its lower bound (-5.0);
# confirm the optimizer handles a start on the boundary as intended.
pz = let pnames = ["σ_i", "B", "λ", "σ_a", "σ_s", "ϕ", "τ_ϕ"]
    Dict("name" => pnames,
         "fit" => trues(length(pnames)),
         "initial" => [2.0, 15.0, -5.0, 100.0, 2.0, 0.2, 0.005],
         "lb" => [eps(), 4.0, -5.0, eps(), eps(), eps(), eps()],
         "ub" => [10.0, 100.0, 5.0, 800.0, 40.0, 2.0, 10.0])
end

# Choice-observation parameters, same layout (no bounds given here).
pd = Dict("name" => ["bias", "lapse"],
          "fit" => trues(2),
          "initial" => [0.0, 0.5]);

In [5]:
# Load the choice data for the listed sessions of each rat into a single
# data object (presumably merged across sessions — confirm in pulse_input_DDM).
data = aggregate_choice_data(
    path,
    sessids,
    ratnames,
);

In [6]:
# Bin the click inputs with a 10 ms time step (dt = 0.01, presumably seconds).
data = bin_clicks!(data; dt = 0.01);

In [7]:
# Fit the maximum-likelihood parameters by gradient-based optimization, timing
# the call. Only the first two returned values are kept; they replace pz and pd.
@time pz, pd = optimize_model(pz, pd, data; f_tol = 1e-9);

└ @ pulse_input_DDM /usr/people/briandd/.julia/packages/pulse_input_DDM/edA5O/src/wrapper_functions.jl:60


Iter     Function value   Gradient norm 
     0     5.255149e+03     1.156880e+03
     1     5.116243e+03     8.875782e+02
     2     4.900577e+03     2.138011e+02
     3     4.895523e+03     2.171708e+02
     4     4.852006e+03     1.618516e+02
     5     4.821839e+03     1.492736e+02
     6     4.796827e+03     5.548819e+01
     7     4.788745e+03     5.265261e+01
     8     4.782937e+03     6.176447e+01
     9     4.777379e+03     1.373893e+01
    10     4.777127e+03     2.525360e+01
    11     4.774601e+03     2.783452e+01
    12     4.772160e+03     3.570138e+01
    13     4.770569e+03     3.024035e+01
    14     4.769563e+03     2.880265e+01
    15     4.768599e+03     1.850554e+01
    16     4.767813e+03     1.441823e+01
    17     4.767020e+03     1.007484e+01
    18     4.766396e+03     6.265240e+00
    19     4.765841e+03     8.009592e+00
    20     4.765601e+03     1.283242e+01
    21     4.765267e+03     1.972953e+01
    22     4.764995e+03     1.503726e+01
    23     4.764

In [None]:
# Compute the Hessian of the log-likelihood at the fitted parameters and use it
# to derive confidence intervals; the updated dictionaries replace pz and pd.
(pz, pd) = compute_H_CI!(pz, pd, data);

In [None]:
using DataFrames
# Print both parameter dictionaries as tables (one column per dictionary key).
# `show` does not end its output with a newline, so break the line explicitly;
# otherwise the second table starts on the last row of the first one.
show(DataFrame(pz), allcols=true)
println()
show(DataFrame(pd), allcols=true)
println()

In [None]:
# Simulate choices on every trial from the fitted (ML) parameters. The
# simulation mutates its first argument, so work on a deep copy and keep the
# observed `data` intact; ML_data["pokedR"] is read below as the simulated choices.
ML_data = deepcopy(data)
sample_choices_all_trials!(ML_data, pz["final"], pd["final"])

# Per trial, take the last element of diffLR (the final left/right click
# difference), which determines the correct choice.
ΔLR = let dt = data["dt"]
    [diffLR(nT, L, R, dt)[end]
     for (nT, L, R) in zip(data["nT"], data["leftbups"], data["rightbups"])]
end;

In [None]:
# Bin trials into quantiles of ΔLR and compute, per bin, the fraction of
# rightward choices in the observed data and in the ML-simulated data.
import Pandas: qcut
import Statistics: mean

nbins = 15;
# pandas.qcut raises on repeated quantile edges (possible when ΔLR has many tied
# values); `duplicates="drop"` merges them instead, so the number of bins that
# actually exist can be smaller than `nbins`.
conds, qcut_bins = qcut(ΔLR, nbins, labels=false, retbins=true, duplicates="drop");
conds = conds .+ 1;  # pandas bin codes are 0-based; shift to 1-based
nbins_used = length(qcut_bins) - 1

frac_choice_ML = [mean(ML_data["pokedR"][conds .== i]) for i in 1:nbins_used]
frac_choice_data = [mean(data["pokedR"][conds .== i]) for i in 1:nbins_used];

In [None]:
# Fit a logistic regression of the choice (pokedR) on ΔLR twice: once for the
# observed choices and once for the choices simulated from the ML parameters.
using GLM
GLM_data, GLM_ML = let
    fit_logit(y) = glm(@formula(Y ~ X), DataFrame(X=ΔLR, Y = y), Binomial(), LogitLink())
    fit_logit(data["pokedR"]), fit_logit(ML_data["pokedR"])
end;

In [None]:
using PyPlot

fig = figure(figsize=(6,6))
ax = subplot(111)

# Binned choice fractions, plotted at the midpoint of each quantile bin.
# `data` holds the rat's observed choices (loaded from disk above), so it is
# labelled as data rather than "generative".
bin_centers = qcut_bins[1:end-1] .+ diff(qcut_bins) ./ 2
scatter(bin_centers, frac_choice_data, color="red", label="data")
scatter(bin_centers, frac_choice_ML, color="grey", label="ML")

# Logistic-fit curves evaluated at each trial's ΔLR. Reorder x and y with the
# same permutation: sorting the predictions independently of ΔLR only pairs
# them correctly when the fitted slope is non-negative.
perm = sortperm(ΔLR)
plot(ΔLR[perm], predict(GLM_data)[perm], color="red")
plot(ΔLR[perm], predict(GLM_ML)[perm], color="grey")

ylabel("% poked R")
xlabel(L"\Delta{LR}")
# Hide the top and right axis spines (PyCall attribute syntax; the old
# `ax[:spines]` getindex form is deprecated).
ax.spines["top"].set_color("none")
ax.spines["right"].set_color("none")
legend()

## Repeat the steps above for any other dataset you wish to fit.

In [None]:
# Session IDs for a second dataset, in the same nested layout as `sessids`.
# Each ID is the corresponding `sessids` entry with a final digit of 1 instead
# of 0 (presumably a second session on the same day — confirm).
sessidsF = [[
    17061801, 18041701, 16052601, 16070501, 16111701,
    18010301, 18022601, 17122201, 17081201, 17092601,
    17051901, 18030501, 17011801, 18012201, 17011501,
    17062501, 17112701, 18032701, 16082201, 18060801,
    18032501, 17082901, 17080501, 17081001, 17052301,
    17090801, 17090901, 18041501, 18040501, 17122001]];

## Save the fitted parameter dictionaries (pz, pd) to disk

In [31]:
# Persist the fitted latent (pz) and choice (pd) parameter dictionaries.
# `JLD2.@save` needs the JLD2 module in scope, and nothing earlier loads it.
using JLD2
# NOTE(review): the file name does not mention this rat (P055) and uses the
# `.jld` extension conventionally associated with JLD.jl rather than `.jld2`;
# confirm the intended output path before reusing this notebook.
JLD2.@save "/usr/people/briandd/Projects/pulse_input_DDM.jl/data/results/marino.jld" pz pd